tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python制作CSDN免积分下载器
Mar 10 Python
python概率计算器实例分析
Mar 25 Python
Python中input与raw_input 之间的比较
Aug 20 Python
Python 2.7中文显示与处理方法
Jul 16 Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 Python
Python 中pandas索引切片读取数据缺失数据处理问题
Oct 09 Python
对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解
Feb 11 Python
无惧面试,带你搞懂python 装饰器
Aug 17 Python
python安装第三方库如xlrd的方法
Oct 31 Python
Python自动化办公Excel模块openpyxl原理及用法解析
Nov 05 Python
浅谈Python基础之列表那些事儿
May 11 Python
python使用PySimpleGUI设置进度条及控件使用
Jun 10 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
使用sockets:从新闻组中获取文章(一)
2006/10/09 PHP
Yii框架使用PHPExcel导出Excel文件的方法分析【改进版】
2019/07/24 PHP
PHP使用Session实现上传进度功能详解
2019/08/06 PHP
利用jquery实现下拉框的禁用与启用
2016/12/07 Javascript
js记录点击某个按钮的次数-刷新次数为初始状态的实例
2017/02/15 Javascript
详解vue-router2.0动态路由获取参数
2017/06/14 Javascript
JS时间控制实现动态效果的实例讲解
2017/07/31 Javascript
JS简单获得节点元素的方法示例
2018/02/10 Javascript
js中apply()和call()的区别与用法实例分析
2018/08/14 Javascript
vue、react等单页面项目部署到服务器的方法及vue和react的区别
2018/09/29 Javascript
JS中数据结构之栈
2019/01/01 Javascript
Nodejs让异步变成同步的方法
2019/03/02 NodeJs
javascript判断一个变量是数组还是对象
2019/04/10 Javascript
JS如何实现网站中PC端和手机端自动识别并跳转对应的代码
2020/01/08 Javascript
Vue引入Stylus知识点总结
2020/01/16 Javascript
关于angular浏览器兼容性问题的解决方案
2020/07/26 Javascript
js实现简单选项卡制作
2020/08/05 Javascript
使用vue3重构拼图游戏的实现示例
2021/01/25 Vue.js
Python使用稀疏矩阵节省内存实例
2014/06/27 Python
MySQLdb ImportError: libmysqlclient.so.18解决方法
2014/08/21 Python
Python的Bottle框架中获取制定cookie的教程
2015/04/24 Python
Python爬取网易云音乐上评论火爆的歌曲
2017/01/19 Python
Python基于贪心算法解决背包问题示例
2017/11/27 Python
python 模拟银行转账功能过程详解
2019/08/06 Python
如何解决cmd运行python提示不是内部命令
2020/07/01 Python
python实现数学模型(插值、拟合和微分方程)
2020/11/13 Python
澳大利亚墨尔本的在线时装店:LORETA
2018/09/14 全球购物
俄罗斯游戏商店:Buka
2020/03/01 全球购物
JENNIFER BEHR官网:各种耳环和发饰
2020/06/07 全球购物
linux面试题参考答案(6)
2014/08/29 面试题
初中生自我评价
2014/02/01 职场文书
生日宴会主持词
2014/03/20 职场文书
党员群众路线学习心得体会
2014/11/04 职场文书
2015国际残疾人日活动总结
2015/03/24 职场文书
停电放假通知
2015/04/14 职场文书
研究生毕业登记表的自我鉴定范文
2019/07/15 职场文书