tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
下载糗事百科的内容_python版
Dec 07 Python
python代码制作configure文件示例
Jul 28 Python
python中的字典使用分享
Jul 31 Python
浅析python中SQLAlchemy排序的一个坑
Feb 24 Python
python添加模块搜索路径方法
Sep 11 Python
pyQt4实现俄罗斯方块游戏
Jun 26 Python
详解Python正则表达式re模块
Mar 19 Python
利用python实现AR教程
Nov 20 Python
使用Python实现牛顿法求极值
Feb 10 Python
Python定时从Mysql提取数据存入Redis的实现
May 03 Python
python模块如何查看
Jun 16 Python
Python requests及aiohttp速度对比代码实例
Jul 16 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
利用ThinkPHP内置的ThinkAjax实现异步传输技术的实现方法
2011/12/19 PHP
RR vs IO BO3 第二场2.13
2021/03/10 DOTA
js身份证验证超强脚本
2008/10/26 Javascript
JQuery Dialog(JS 模态窗口,可拖拽的DIV)
2010/02/07 Javascript
jQuery getJSON 处理json数据的代码
2010/07/26 Javascript
修改file按钮的默认样式实现代码
2013/04/23 Javascript
jquery实现的可隐藏重现的靠边悬浮层实例代码
2013/05/27 Javascript
判断js对象是否拥有某一个属性的js代码
2013/08/16 Javascript
NodeJs中的VM模块详解
2015/05/06 NodeJs
跟我学习javascript的隐式强制转换
2015/11/16 Javascript
AngularJS实现分页显示数据库信息
2016/07/01 Javascript
vue同步父子组件和异步父子组件的生命周期顺序问题
2018/10/07 Javascript
详解vuex状态管理模式
2018/11/01 Javascript
使用mpvue搭建一个初始小程序及项目配置方法
2018/12/03 Javascript
详解微信小程序scroll-view横向滚动的实践踩坑及隐藏其滚动条的实现
2019/03/14 Javascript
Vue 实现前进刷新后退不刷新的效果
2019/06/14 Javascript
Vue混入mixins滚动触底的方法
2019/11/22 Javascript
Python虚拟环境virtualenv的安装与使用详解
2017/05/28 Python
Python实现的矩阵类实例
2017/08/22 Python
用python处理MS Word的实例讲解
2018/05/08 Python
Python 删除连续出现的指定字符的实例
2018/06/29 Python
Python比较配置文件的方法实例详解
2019/06/06 Python
python运用sklearn实现KNN分类算法
2019/10/16 Python
详解matplotlib绘图样式(style)初探
2021/02/03 Python
什么是WEB控件?使用WEB控件有哪些优势?
2012/01/21 面试题
50道外企软件测试面试题
2014/08/18 面试题
车间组长岗位职责
2013/12/20 职场文书
优秀社区干部事迹材料
2014/02/03 职场文书
电厂职工自我鉴定
2014/02/20 职场文书
《鲁班和橹板》教学反思
2014/04/27 职场文书
2014国庆黄金周超市促销活动方案
2014/09/21 职场文书
公安机关正风肃纪剖析材料
2014/10/10 职场文书
2014小学语文教师个人工作总结
2014/12/03 职场文书
综合管理员岗位职责
2015/02/11 职场文书
校运会班级霸气口号
2015/12/24 职场文书
golang 生成对应的数据表struct定义操作
2021/04/28 Golang