tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中列表的一些基本操作知识汇总
May 20 Python
详解Python爬虫的基本写法
Jan 08 Python
python脚本替换指定行实现步骤
Jul 11 Python
python实现泊松图像融合
Jul 26 Python
python七夕浪漫表白源码
Apr 05 Python
python基于paramiko将文件上传到服务器代码实现
Jul 08 Python
python实现按行分割文件
Jul 22 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 Python
Python IDLE或shell中切换路径的操作
Mar 09 Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 Python
基于Python实现体育彩票选号器功能代码实例
Sep 16 Python
一文搞懂Python Sklearn库使用
Aug 23 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
异世界新番又来了,同样是从零开始,男主的年龄降到5岁
2020/04/09 日漫
并发下常见的加锁及锁的PHP具体实现代码
2010/10/12 PHP
解析PHP无限级分类方法及代码
2013/06/21 PHP
PHP技术开发微信公众平台
2015/07/22 PHP
orm获取关联表里的属性值
2016/04/17 PHP
PHP与服务器文件系统的简单交互
2016/10/21 PHP
jquery 插件 web2.0分格的分页脚本,可用于ajax无刷新分页
2008/12/25 Javascript
Extjs学习笔记之六 面版
2010/01/08 Javascript
Raphael带文本标签可拖动的图形实现代码
2013/02/20 Javascript
Js 代码中,ajax请求地址后加随机数防止浏览器缓存的原因
2013/05/07 Javascript
关于页面嵌入swf覆盖div层的问题的解决方法
2014/02/11 Javascript
Node.js实现的简易网页抓取功能示例
2014/12/05 Javascript
JavaScript实现横向滑出的多级菜单效果
2015/10/09 Javascript
详解XMLHttpRequest(二)响应属性、二进制数据、监测上传下载进度
2016/09/14 Javascript
Nodejs 和 Electron ubuntu下快速安装过程
2018/05/04 NodeJs
Vue动态生成表格的行和列
2019/07/18 Javascript
layui.use模块外部使用其内部定义的js封装函数方法
2019/09/16 Javascript
JavaScript实现随机点名程序
2020/03/25 Javascript
pyqt和pyside开发图形化界面
2014/01/22 Python
举例讲解Python面相对象编程中对象的属性与类的方法
2016/01/19 Python
python将回车作为输入内容的实例
2018/06/23 Python
Python函数和模块的使用总结
2019/05/20 Python
python中的协程深入理解
2019/06/10 Python
详解Python 多线程 Timer定时器/延迟执行、Event事件
2019/06/27 Python
Python在Matplotlib图中显示中文字体的操作方法
2019/07/29 Python
scrapy框架携带cookie访问淘宝购物车功能的实现代码
2020/07/07 Python
基于HTML5的WebSocket的实例代码
2018/08/15 HTML / CSS
美国女士时尚珠宝及配饰购物网站:Icing
2018/07/02 全球购物
英国蜡烛、蜡烛配件和家居香氛购买网站:Yankee Candle
2018/12/12 全球购物
size?爱尔兰官方网站:英国伦敦的球鞋精品店
2019/03/31 全球购物
Cocopanda波兰:购买化妆品、护肤品、护发和香水
2020/05/25 全球购物
母亲80寿诞答谢词
2014/01/16 职场文书
班级入场式解说词
2014/02/01 职场文书
小学秋季运动会报道稿
2014/09/30 职场文书
2016年端午节红领巾广播稿
2015/12/18 职场文书
win10清理dns缓存
2022/04/19 数码科技