tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3中多线程编程的队列运作示例
Apr 16 Python
详解在Python程序中解析并修改XML内容的方法
Nov 16 Python
浅谈scrapy 的基本命令介绍
Jun 13 Python
python的多重继承的理解
Aug 06 Python
Python从Excel中读取日期一列的方法
Nov 28 Python
Python 计算任意两向量之间的夹角方法
Jul 05 Python
Django 全局的static和templates的使用详解
Jul 19 Python
Django 实现前端图片压缩功能的方法
Aug 07 Python
python使用建议与技巧分享(二)
Aug 17 Python
Python可以用来做什么
Nov 23 Python
Python 实现进度条的六种方式
Jan 06 Python
Python借助with语句实现代码段只执行有限次
Mar 23 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
PHP 截取字符串函数整理(支持gb2312和utf-8)
2010/02/16 PHP
四个常见html网页乱码问题及解决办法
2015/09/08 PHP
基于PHP实现数据分页显示功能
2016/05/26 PHP
PHP自定义序列化接口Serializable用法分析
2017/12/29 PHP
Yii框架中用response保存cookie,用request读取cookie的原理解析
2019/09/04 PHP
JavaScript 操作键盘的Enter事件(键盘任何事件),兼容多浏览器
2010/10/11 Javascript
nodejs实用示例 缩址还原
2010/12/28 NodeJs
jquery在Chrome下获取图片的长宽问题解决
2013/03/20 Javascript
动态加载js和css(外部文件)
2013/04/17 Javascript
js 弹出新页面避免被浏览器、ad拦截的一种新方法
2014/04/30 Javascript
jQuery ajax serialize() 方法使用示例
2014/11/02 Javascript
javascript实现微信分享
2014/12/23 Javascript
jQuery中:empty选择器用法实例
2014/12/30 Javascript
JQUERY表单暂存功能插件分享
2016/02/23 Javascript
使用JavaScript脚本判断页面是否在微信中被打开
2016/03/06 Javascript
浅谈JavaScript的全局变量与局部变量
2016/06/10 Javascript
js实现常见的工具条效果
2017/03/02 Javascript
vue 获取及修改store.js里的公共变量实例
2019/11/06 Javascript
JavaScript canvas动画实现时钟效果
2020/02/10 Javascript
JavaScript canvas实现雨滴特效
2021/01/10 Javascript
python实现数据分析与建模
2019/07/11 Python
OpenCV python sklearn随机超参数搜索的实现
2020/01/17 Python
Python小白不正确的使用类变量实例
2020/05/29 Python
Python flask框架端口失效解决方案
2020/06/04 Python
解决PyCharm无法使用lxml库的问题(图解)
2020/12/22 Python
python常量折叠基础知识点讲解
2021/02/28 Python
英国在线购买轮胎、预订汽车、汽车维修和装配网站:Protyre
2020/04/12 全球购物
如何清空Session
2015/02/23 面试题
艺术设计专业求职自荐信
2014/05/19 职场文书
通信工程求职信
2014/07/16 职场文书
机电一体化应届生求职信
2014/08/09 职场文书
殡葬服务心得体会
2014/09/11 职场文书
欢迎新生标语
2014/10/06 职场文书
群众路线教育实践活动总结
2014/10/30 职场文书
汽车转让协议书
2015/01/29 职场文书
2016年感恩节活动总结大全
2016/04/01 职场文书