tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django框架中方法的访问和查找
Jul 15 Python
Python实现登录接口的示例代码
Jul 21 Python
python中通过预先编译正则表达式提高效率
Sep 25 Python
python中的计时器timeit的使用方法
Oct 20 Python
python通过TimedRotatingFileHandler按时间切割日志
Jul 17 Python
解决Python设置函数调用超时,进程卡住的问题
Aug 08 Python
解决Atom安装Hydrogen无法运行python3的问题
Aug 28 Python
python 函数中的参数类型
Feb 11 Python
快速解决jupyter notebook启动需要密码的问题
Apr 21 Python
Python PyQt5运行程序把输出信息展示到GUI图形界面上
Apr 27 Python
pycharm软件实现设置自动保存操作
Jun 08 Python
Pycharm同步远程服务器调试的方法步骤
Nov 04 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
基于Snoopy的PHP近似完美获取网站编码的代码
2011/10/23 PHP
php获取mysql字段名称和其它信息的例子
2014/04/14 PHP
PHP实现的基于单向链表解决约瑟夫环问题示例
2017/09/30 PHP
PHP rsa加密解密算法原理解析
2020/12/09 PHP
Jquery网页出现的乱码问题的三种解决方法
2013/06/30 Javascript
css样式标签和js语法属性区别
2013/11/06 Javascript
JavaScript按位运算符的应用简析
2014/02/04 Javascript
javascript实现base64 md5 sha1 密码加密
2015/09/09 Javascript
基于jQuery实现Tabs选项卡自定义插件
2016/11/21 Javascript
js获取隐藏元素的宽高
2017/02/24 Javascript
vue 2.5.1 源码学习 之Vue.extend 和 data的合并策略
2019/06/04 Javascript
Angular单元测试之事件触发的实现
2020/01/20 Javascript
vue下拉刷新组件的开发及slot的使用详解
2020/12/23 Vue.js
使用AutoJs实现微信抢红包的代码
2020/12/31 Javascript
js面向对象方式实现拖拽效果
2021/03/03 Javascript
[06:45]DOTA2卡尔工作室 英雄介绍幻影长矛手篇
2013/07/12 DOTA
[04:21]狐狸妈带你到现场 DOTA2 TI中国区预选赛线下赛路线指引
2014/05/22 DOTA
Python创建系统目录的方法
2015/03/11 Python
从Python的源码来解析Python下的freeblock
2015/05/11 Python
Python selenium 三种等待方式详解(必会)
2016/09/15 Python
python 实现有道翻译功能
2021/02/26 Python
Original Penguin美国官网:布拉德皮特、强尼德普喜爱的服装品牌
2016/10/25 全球购物
美国领先的奢侈美容零售商:Bluemercury
2017/07/26 全球购物
美国礼品卡商城: Gift Card Mall
2017/08/25 全球购物
斯图尔特·韦茨曼鞋加拿大官网:Stuart Weitzman加拿大
2019/10/13 全球购物
Java基础类库面试题
2013/09/04 面试题
生产现场工艺工程师岗位职责
2013/11/28 职场文书
机械制造专业个人的自我评价
2013/12/28 职场文书
公务员转正鉴定材料
2014/02/11 职场文书
现金出纳岗位职责
2014/03/15 职场文书
求职信名称怎么写
2014/05/26 职场文书
幼儿教师师德师风演讲稿
2014/08/22 职场文书
四风查摆剖析材料
2014/10/10 职场文书
复试通知单模板
2015/04/24 职场文书
导游词之河北白洋淀
2020/01/15 职场文书
Flink 侧流输出源码示例解析
2022/09/23 Servers