tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串加密解密的三种方法分享(base64 win32com)
Jan 19 Python
详解Python中time()方法的使用的教程
May 22 Python
python实现查找两个字符串中相同字符并输出的方法
Jul 11 Python
关于pip的安装,更新,卸载模块以及使用方法(详解)
May 19 Python
微信跳一跳python自动代码解读1.0
Jan 12 Python
Python数据集切分实例
Dec 08 Python
python创造虚拟环境方法总结
Mar 04 Python
pytorch实现Tensor变量之间的转换
Feb 17 Python
Python: tkinter窗口屏幕居中,设置窗口最大,最小尺寸实例
Mar 04 Python
构建高效的python requests长连接池详解
May 02 Python
Python获取android设备cpu和内存占用情况
Nov 15 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
《PHP边学边教》(02.Apache+PHP环境配置――下篇)
2006/12/13 PHP
php class类的用法详细总结
2013/10/17 PHP
PHP中round()函数对浮点数进行四舍五入的方法
2014/11/19 PHP
如何使用jQuery+PHP+MySQL来实现一个在线测试项目
2015/04/26 PHP
php通过排列组合实现1到9数字相加都等于20的方法
2015/08/03 PHP
php实现倒计时效果
2015/12/19 PHP
PHP页面跳转操作实例分析(header方法)
2016/09/28 PHP
PHP的中使用非缓冲模式查询数据库的方法
2017/02/05 PHP
JQuery 浮动导航栏实现代码
2009/08/27 Javascript
Javascript的构造函数和constructor属性
2010/01/09 Javascript
Javascript学习笔记2 函数
2010/01/11 Javascript
jquery中使用$(#form).submit()重写提交表单无效原因分析及解决
2013/03/25 Javascript
Javascript 两种刷新方法以及区别和适用范围
2017/01/17 Javascript
Vue实现一个返回顶部backToTop组件
2017/07/25 Javascript
Angular项目从新建、打包到nginx部署全过程记录
2017/12/09 Javascript
js实现各浏览器全屏代码实例
2018/07/03 Javascript
[56:29]Secret vs Optic 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
Python中处理字符串之endswith()方法的使用简介
2015/05/18 Python
Django2.1集成xadmin管理后台所遇到的错误集锦(填坑)
2018/12/20 Python
详解如何在Apache中运行Python WSGI应用
2019/01/02 Python
Python如何获得百度统计API的数据并发送邮件示例代码
2019/01/27 Python
Python+selenium点击网页上指定坐标的实例
2019/07/05 Python
python中strip(),lstrip(),rstrip()函数的使用讲解
2020/11/17 Python
中国一家专注拼团的社交购物网站:拼多多
2018/06/13 全球购物
世界上最伟大的马产品:Equiderma
2020/01/07 全球购物
工程造价与财务管理专业应届生求职信
2013/10/06 职场文书
大专生简历的自我评价
2013/11/26 职场文书
个人自我鉴定写法
2013/11/30 职场文书
《一件运动衫》教学反思
2014/02/19 职场文书
市场营销专业毕业生求职信
2014/03/26 职场文书
四风问题对照检查材料思想汇报
2014/10/07 职场文书
交通事故死亡赔偿协议书
2014/12/03 职场文书
优秀校长事迹材料
2014/12/24 职场文书
商标侵权律师函
2015/05/27 职场文书
天气温馨提示语
2015/07/14 职场文书
一文搞懂MySQL索引页结构
2022/02/28 MySQL