tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python切片用法实例教程
Sep 08 Python
使用C语言扩展Python程序的简单入门指引
Apr 14 Python
Python编程之属性和方法实例详解
May 19 Python
使用Python获取网段IP个数以及地址清单的方法
Nov 01 Python
解决python 未发现数据源名称并且未指定默认驱动程序的问题
Dec 07 Python
Python实现一个数组除以一个数的例子
Jul 20 Python
django ModelForm修改显示缩略图 imagefield类型的实例
Jul 28 Python
Python绘制全球疫情变化地图的实例代码
Apr 20 Python
keras训练浅层卷积网络并保存和加载模型实例
Jul 02 Python
python 如何使用find和find_all爬虫、找文本的实现
Oct 16 Python
Python函数调用追踪实现代码
Nov 27 Python
python实现简单猜单词游戏
Dec 24 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
解析PHP多种序列化与反序列化的方法
2013/06/06 PHP
解析argc argv在php中的应用
2013/06/24 PHP
dedecms函数分享之获取某一栏目所有子栏目
2014/05/19 PHP
php中将一个对象保存到Session中的方法
2015/03/13 PHP
PHP中的switch语句的用法实例详解
2015/10/21 PHP
简单介绍PHP非阻塞模式
2016/03/03 PHP
php面向对象值单例模式
2016/05/03 PHP
PHP递归获取目录内所有文件的实现方法
2016/11/01 PHP
Yii2框架配置文件(Application属性)与调试技巧实例分析
2019/05/27 PHP
thinkPHP5.1框架中Request类四种调用方式示例
2019/08/03 PHP
用JavaScript编写COM组件的步骤
2009/03/17 Javascript
关闭浏览器窗口弹出提示框并且可以控制其失效
2014/04/15 Javascript
解决ueditor jquery javascript 取值问题
2014/12/30 Javascript
JS获取和修改元素样式的实例代码
2016/08/06 Javascript
bootstrap手风琴制作方法详解
2017/01/11 Javascript
jQuery Datatable 多个查询条件自定义提交事件(推荐)
2017/08/24 jQuery
vue实现样式之间的切换及vue动态样式的实现方法
2017/12/19 Javascript
vue项目关闭eslint校验
2018/03/21 Javascript
基于vue循环列表时点击跳转页面的方法
2018/08/31 Javascript
node中的session的具体使用
2018/09/14 Javascript
微信小程序按钮点击跳转页面详解
2019/05/06 Javascript
微信小程序云开发(数据库)详解
2019/05/17 Javascript
微信小程序实现商城倒计时
2020/11/01 Javascript
python实现可将字符转换成大写的tcp服务器实例
2015/04/29 Python
Python实现基于POS算法的区块链
2018/08/07 Python
Python通过for循环理解迭代器和生成器实例详解
2019/02/16 Python
python深copy和浅copy区别对比解析
2019/12/26 Python
python爬取王者荣耀全皮肤的简单实现代码
2020/01/31 Python
css3的过滤效果简单实例
2016/08/03 HTML / CSS
美国亚马逊旗下时尚女装网店:SHOPBOP(支持中文)
2020/10/17 全球购物
幼儿园教师考核制度
2014/02/01 职场文书
项目建议书格式
2014/03/12 职场文书
消防安全宣传口号
2014/06/10 职场文书
休学证明范本
2015/06/19 职场文书
先进个人主要事迹怎么写
2015/11/04 职场文书
2016年党员承诺书范文
2016/03/24 职场文书