tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Trie树实现字典排序
Mar 28 Python
Python的Django框架可适配的各种数据库介绍
Jul 15 Python
Python语言描述最大连续子序列和
Dec 05 Python
python实现堆和索引堆的代码示例
Mar 19 Python
Python闭包执行时值的传递方式实例分析
Jun 04 Python
Python用于学习重要算法的模块pygorithm实例浅析
Aug 16 Python
Django将默认的SQLite更换为MySQL的实现
Nov 18 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 Python
Tensorflow 模型转换 .pb convert to .lite实例
Feb 12 Python
jupyter 添加不同内核的操作
Feb 06 Python
如何用python爬取微博热搜数据并保存
Feb 20 Python
python如何读取和存储dict()与.json格式文件
Jun 25 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
用Php实现链结人气统计
2006/10/09 PHP
php中操作memcached缓存进行增删改查数据的实现代码
2014/08/15 PHP
PHP数字前补0的自带函数sprintf 和number_format的用法(详解)
2017/02/06 PHP
php利用ZipArchive类操作文件的实例
2020/01/21 PHP
Javascript 中的类和闭包
2010/01/08 Javascript
读取input:file的路径并显示本地图片的方法
2013/09/23 Javascript
[将免费进行到底]在Amazon的一年免费服务器上安装Node.JS, NPM和OurJS博客
2014/08/18 Javascript
JQuery日期插件datepicker的使用方法
2016/03/03 Javascript
深入理解jQuery之事件移除
2016/06/02 Javascript
仅一个form表单 js实现注册信息依次填写提交功能
2016/06/12 Javascript
作为老司机使用 React 总结的 11 个经验教训
2017/04/08 Javascript
jQuery实现简单的回到顶部totop功能示例
2017/10/16 jQuery
JavaScript 中使用 Generator的方法
2017/12/29 Javascript
NW.js 简介与使用方法
2018/02/01 Javascript
[30:37]【全国守擂赛】第三周擂主赛 Dark Knight vs. Leopard Gaming
2020/05/04 DOTA
更改Python命令行交互提示符的方法
2015/01/14 Python
python实现用户登陆邮件通知的方法
2015/07/09 Python
python实现基于SVM手写数字识别功能
2020/05/27 Python
详解Python3.6的py文件打包生成exe
2018/07/13 Python
python代码过长的换行方法
2018/07/19 Python
Python使用googletrans报错的解决方法
2018/09/25 Python
python实现网页自动签到功能
2019/01/21 Python
PyQt5实现QLineEdit添加clicked信号的方法
2019/06/25 Python
python 常用日期处理-- datetime 模块的使用
2020/09/02 Python
HTML5和以前HTML4的区别整理
2013/10/20 HTML / CSS
Java中compareTo和compare的区别
2016/04/12 面试题
linux面试题参考答案(11)
2012/05/01 面试题
欢送退休感言
2014/02/08 职场文书
总经理助理的职责
2014/03/14 职场文书
教研处工作方案
2014/05/26 职场文书
亲子阅读的活动方案
2014/08/15 职场文书
简爱读书笔记
2015/06/26 职场文书
三八节活动简报
2015/07/20 职场文书
防止web项目中的SQL注入
2021/12/06 MySQL
利用uni-app生成微信小程序的踩坑记录
2022/04/05 Javascript
从原生JavaScript到React深入理解
2022/07/23 Javascript