tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现通过shelve修改对象实例
Sep 26 Python
python写入中英文字符串到文件的方法
May 06 Python
在Python中操作文件之truncate()方法的使用教程
May 25 Python
Python多层嵌套list的递归处理方法(推荐)
Jun 08 Python
python使用RNN实现文本分类
May 24 Python
解决Pycharm调用Turtle时 窗口一闪而过的问题
Feb 16 Python
利用Python实现微信找房机器人实例教程
Mar 10 Python
python实现坦克大战游戏 附详细注释
Mar 27 Python
TensorFlow学习之分布式的TensorFlow运行环境
Feb 05 Python
python 调用Google翻译接口的方法
Dec 09 Python
Python爬取梨视频的示例
Jan 29 Python
TensorFlow低版本代码自动升级为1.0版本
Feb 20 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
PHP面向对象精要总结
2014/11/07 PHP
护卫神php套件 php版本升级方法(php5.5.24)
2015/05/10 PHP
PHP借助phpmailer发送邮件
2015/05/11 PHP
php实现短信发送代码
2015/07/05 PHP
php实现购物车产品删除功能(2)
2020/07/23 PHP
Redis构建分布式锁
2017/03/28 PHP
php 浮点数比较方法详解
2017/05/05 PHP
PHP自定义序列化接口Serializable用法分析
2017/12/29 PHP
一组JS创建和操作表格的函数集合
2009/05/07 Javascript
多浏览器支持的右下角浮动窗口
2010/04/01 Javascript
jquery操作select option 的代码小结
2011/06/21 Javascript
JavaScript前补零操作实例
2015/03/11 Javascript
C++中的string类的用法小结
2015/08/07 Javascript
如何使用headjs来管理和异步加载js
2016/11/29 Javascript
AngularJS自定义指令详解(有分页插件代码)
2017/06/12 Javascript
详解webpack介绍&安装&常用命令
2017/06/29 Javascript
js中变量的连续赋值(实例讲解)
2017/07/08 Javascript
微信jssdk逻辑在vue中的运用详解
2018/11/14 Javascript
jQuery实现全选、反选和不选功能的方法详解
2019/12/04 jQuery
Vue绑定用户接口实现代码示例
2020/11/04 Javascript
ReactRouter的实现方法
2021/01/25 Javascript
[00:43]DOTA2小紫本全民票选福利PA至宝全方位展示
2014/11/25 DOTA
利用Python绘制MySQL数据图实现数据可视化
2015/03/30 Python
在Python程序中操作MySQL的基本方法
2015/07/29 Python
python中将函数赋值给变量时需要注意的一些问题
2017/08/18 Python
对tensorflow 的模型保存和调用实例讲解
2018/07/28 Python
python numpy中cumsum的用法详解
2019/10/17 Python
基于Python获取照片的GPS位置信息
2020/01/20 Python
什么是CSS3 HSLA色彩模式?HSLA模拟渐变色条
2016/04/26 HTML / CSS
使用canvas对多图片拼合并导出图片的方法
2018/08/28 HTML / CSS
美国维生素、补充剂、保健食品购物网站:Vitacost
2016/08/05 全球购物
加拿大票务网站:Ticketmaster加拿大
2017/07/17 全球购物
Boston Proper官网:美国女装品牌
2017/10/30 全球购物
烈士陵园观后感
2015/06/08 职场文书
青年志愿者活动感想
2015/08/07 职场文书
优秀家长事迹材料(2016推荐版)
2016/02/29 职场文书