TensorFlow模型保存和提取的方法


Posted in Python onMarch 08, 2018

一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件:

TensorFlow模型保存和提取的方法

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt") ,加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行 
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确 
 
 
# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型 
 
# 执行本段程序时注意当前的工作路径 
import tensorflow as tf 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 saver.save(sess, "Model/model.ckpt") 
 
 
# Part2: 加载TensorFlow模型的方法 
 
import tensorflow as tf 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" 
 print(sess.run(result)) # [ 3.] 
 
 
# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图 
 
import tensorflow as tf 
 
saver = tf.train.import_meta_graph("Model/model.ckpt.meta") 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 
 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.] 
 
 
# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名 
 
import tensorflow as tf 
 
# 声明的变量名称name与已保存的模型中的变量名称name不一致 
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") 
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") 
result = u1 + u2 
 
# 若直接生命Saver类对象,会报错变量找不到 
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} 
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 
saver = tf.train.Saver({"v1": u1, "v2": u2}) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") 
 print(sess.run(result)) # [ 3.] 
 
 
# Part5: 保存滑动平均模型 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
for variables in tf.global_variables(): 
 print(variables.name) # v:0 
 
ema = tf.train.ExponentialMovingAverage(0.99) 
maintain_averages_op = ema.apply(tf.global_variables()) 
for variables in tf.global_variables(): 
 print(variables.name) # v:0 
       # v/ExponentialMovingAverage:0 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 sess.run(tf.assign(v, 10)) 
 sess.run(maintain_averages_op) 
 saver.save(sess, "Model/model_ema.ckpt") 
 print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905] 
 
 
# Part6: 通过变量重命名直接读取变量的滑动平均值 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model_ema.ckpt") 
 print(sess.run(v)) # 0.0999999 
 
 
# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
# 注意此处的变量名称name一定要与已保存的变量名称一致 
ema = tf.train.ExponentialMovingAverage(0.99) 
print(ema.variables_to_restore()) 
# {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} 
# 此处的v取自上面变量v的名称name="v" 
 
saver = tf.train.Saver(ema.variables_to_restore()) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model_ema.ckpt") 
 print(sess.run(v)) # 0.0999999 
 
 
# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中 
 
import tensorflow as tf 
from tensorflow.python.framework import graph_util 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分 
 graph_def = tf.get_default_graph().as_graph_def() 
 output_graph_def = graph_util.convert_variables_to_constants(sess, 
              graph_def, ['add']) 
 
 with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f: 
  f.write(output_graph_def.SerializeToString()) 
 
 
# Part9: 载入包含变量及其取值的模型 
 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
 
with tf.Session() as sess: 
 model_filename = "Model/combined_model.pb" 
 with gfile.FastGFile(model_filename, 'rb') as f: 
  graph_def = tf.GraphDef() 
  graph_def.ParseFromString(f.read()) 
 
 result = tf.import_graph_def(graph_def, return_elements=["add:0"]) 
 print(sess.run(result)) # [array([ 3.], dtype=float32)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的二维码生成小软件
Jul 11 Python
python3编写C/S网络程序实例教程
Aug 25 Python
Python中的hypot()方法使用简介
May 18 Python
简单谈谈Python中的json与pickle
Jul 19 Python
python实现归并排序算法
Nov 22 Python
Python函数装饰器实现方法详解
Dec 22 Python
用Python解数独的方法示例
Oct 24 Python
Python for循环通过序列索引迭代过程解析
Feb 07 Python
使用python实现多维数据降维操作
Feb 24 Python
PyQt5+python3+pycharm开发环境配置教程
Mar 24 Python
python批量生成身份证号到Excel的两种方法实例
Jan 14 Python
python pyhs2 的安装操作
Apr 07 Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
TensorFlow模型保存/载入的两种方法
Mar 08 #Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
You might like
PHP连接sql server 2005环境配置及问题解决
2014/08/08 PHP
WordPress中用于获取搜索表单的PHP函数使用解析
2016/01/05 PHP
PHP实现微信公众号验证Token的示例代码
2019/12/16 PHP
JS location几个方法小姐
2008/07/09 Javascript
基于jquery的获取浏览器窗口大小的代码
2011/03/28 Javascript
Javascript面向对象编程
2012/03/18 Javascript
JavaScript高级程序设计 阅读笔记(十二) js内置对象Math
2012/08/14 Javascript
javascript scrollTop正解使用方法
2013/11/14 Javascript
JS实现的一个简单的Autocomplete自动完成例子
2014/04/16 Javascript
浅析Node.js中的内存泄漏问题
2015/06/23 Javascript
Underscore.js 1.3.3 中文注释翻译说明
2015/06/25 Javascript
PHP+jQuery+Ajax+Mysql如何实现发表心情功能
2015/08/06 Javascript
javascript 中Cookie读、写与删除操作
2017/03/29 Javascript
NodeJS如何实现同步的方法示例
2018/08/24 NodeJs
vue 项目中使用Loading组件的示例代码
2018/08/31 Javascript
express如何解决ajax跨域访问session失效问题详解
2019/06/20 Javascript
JS对象属性的检测与获取操作实例分析
2020/03/17 Javascript
JS获取一个字符串中指定字符串第n次出现的位置
2021/02/10 Javascript
树莓派中python获取GY-85九轴模块信息示例
2013/12/05 Python
Python对列表排序的方法实例分析
2015/05/16 Python
解决Python出现_warn_unsafe_extraction问题的方法
2016/03/24 Python
Python简单实现的代理服务器端口映射功能示例
2018/04/08 Python
Pandas读写CSV文件的方法示例
2019/03/27 Python
Html5移动端适配IphoneX等机型的方法
2019/06/25 HTML / CSS
柒牌官方商城:中国男装优秀品牌
2017/06/30 全球购物
芬兰灯具网上商店:Nettilamppu.fi
2018/06/30 全球购物
图库照片、免版税图片、矢量艺术、视频片段:Depositphotos
2019/08/02 全球购物
社区食品安全实施方案
2014/03/28 职场文书
学习雷锋倡议书
2014/04/15 职场文书
宣传普通话标语
2014/06/27 职场文书
人事行政专员岗位职责
2014/07/23 职场文书
2014年度思想工作总结
2014/11/27 职场文书
会计出纳岗位职责
2015/03/31 职场文书
人事行政主管岗位职责
2015/04/09 职场文书
副校长2015年教育教学工作总结
2015/07/27 职场文书
学校教代会开幕词
2016/03/04 职场文书