tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫_实现校园网自动重连脚本的教程
Apr 22 Python
利用Python将每日一句定时推送至微信的实现方法
Aug 13 Python
python使用PIL模块获取图片像素点的方法
Jan 08 Python
python 寻找离散序列极值点的方法
Jul 10 Python
教你如何编写、保存与运行Python程序的方法
Jul 12 Python
python向图片里添加文字
Nov 26 Python
Python批量启动多线程代码实例
Feb 18 Python
keras模型保存为tensorflow的二进制模型方式
May 25 Python
详解用Python爬虫获取百度企业信用中企业基本信息
Jul 02 Python
python 浮点数四舍五入需要注意的地方
Aug 18 Python
python实现图像高斯金字塔的示例代码
Dec 11 Python
OpenCV3.3+Python3.6实现图片高斯模糊
May 18 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
symfony2.4的twig中date用法分析
2016/03/18 PHP
JavaScript脚本性能的优化方法
2007/02/02 Javascript
javascript 函数使用说明
2010/04/07 Javascript
Document:getElementsByName()使用方法及示例
2013/10/28 Javascript
封装了一个支持匿名函数的Javascript事件监听器
2014/06/05 Javascript
js监听鼠标点击和键盘点击事件并自动跳转页面
2014/09/24 Javascript
PHP结合jQuery实现的评论顶、踩功能
2015/07/22 Javascript
基于html5和nodejs相结合实现websocket即使通讯
2015/11/19 NodeJs
jquery easyui datagrid实现增加,修改,删除方法总结
2016/05/25 Javascript
JavaScript 控制字体大小设置的方法
2016/11/23 Javascript
利用JavaScript在网页实现八数码启发式A*算法动画效果
2017/04/16 Javascript
JavaScript脚本语言是什么_动力节点Java学院整理
2017/06/26 Javascript
Koa2微信公众号开发之消息管理
2018/05/16 Javascript
node.js遍历目录的方法示例
2018/08/01 Javascript
详解Vue开发微信H5微信分享签名失败问题解决方案
2018/08/09 Javascript
koa2实现登录注册功能的示例代码
2018/12/03 Javascript
简述Vue中容易被忽视的知识点
2019/12/09 Javascript
Vue使用JSEncrypt实现rsa加密及挂载方法
2020/02/07 Javascript
Vue中 axios delete请求参数操作
2020/08/25 Javascript
python批量修改文件后缀示例代码分享
2013/12/24 Python
Python SQLAlchemy基本操作和常用技巧(包含大量实例,非常好)
2014/05/06 Python
简单总结Python中序列与字典的相同和不同之处
2016/01/19 Python
Python Socket TCP双端聊天功能实现过程详解
2020/06/15 Python
python3.9和pycharm的安装教程并创建简单项目的步骤
2021/02/03 Python
手工制作的男士奢华英国鞋和服装之家:Goodwin Smith
2019/06/21 全球购物
荣耀俄罗斯官网:HONOR俄罗斯
2020/10/31 全球购物
异步传递消息系统的作用
2016/05/01 面试题
机电专业个人求职信范文
2013/12/30 职场文书
职业生涯规划书前言
2014/04/15 职场文书
《鹬蚌相争》教学反思
2014/04/22 职场文书
先进党员事迹材料
2014/12/24 职场文书
公司租车协议书
2015/01/29 职场文书
档案管理员岗位职责
2015/02/12 职场文书
会计工作自我鉴定范文
2019/06/21 职场文书
十大最强格斗系宝可梦,超梦X仅排第十,第二最重格斗礼仪
2022/03/18 日漫
【海涛教你打dota】体验一超神发条:咱是抢盾专业户
2022/04/01 DOTA