TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python列表的常用操作方法小结
May 21 Python
python 链接和操作 memcache方法
Mar 04 Python
Python中xrange与yield的用法实例分析
Dec 26 Python
Python批量发送post请求的实现代码
May 05 Python
python3.5绘制随机漫步图
Aug 27 Python
浅谈Python中的全局锁(GIL)问题
Jan 11 Python
Python如何爬取微信公众号文章和评论(基于 Fiddler 抓包分析)
Jun 28 Python
利用Python绘制Jazz网络图的例子
Nov 21 Python
如何通过Django使用本地css/js文件
Jan 20 Python
python 下载m3u8视频的示例代码
Nov 11 Python
Selenium关闭INFO:CONSOLE提示的解决
Dec 07 Python
Python网络编程之ZeroMQ知识总结
Apr 25 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
第1次亲密接触PHP5(2)
2006/10/09 PHP
老生常谈PHP位运算的用途
2017/03/12 PHP
php框架CodeIgniter主从数据库配置方法分析
2018/05/25 PHP
JQuery.uploadify 上传文件插件的使用详解 for ASP.NET
2010/01/22 Javascript
一个奇葩的最短的 IE 版本判断JS脚本
2014/05/28 Javascript
javascript表单验证大全
2015/08/12 Javascript
jQuery焦点图左右转换效果
2016/12/12 Javascript
jQuery表单插件ajaxForm实例详解
2017/01/17 Javascript
JS实现按钮控制计时开始和停止功能
2017/07/27 Javascript
Vue-Quill-Editor富文本编辑器的使用教程
2018/09/21 Javascript
用Vue.js方法创建模板并使用多个模板合成
2019/06/28 Javascript
深入webpack打包原理及loader和plugin的实现
2020/05/06 Javascript
2020淘宝618理想生活列车自动领喵币js脚本的代码
2020/06/02 Javascript
[00:32]2018DOTA2亚洲邀请赛Secret出场
2018/04/03 DOTA
SQLite3中文编码 Python的实现
2017/01/11 Python
python删除某个字符
2018/03/19 Python
Python可变参数*args和**kwargs用法实例小结
2018/04/27 Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
2018/12/29 Python
python  logging日志打印过程解析
2019/10/22 Python
Django自定义列表 models字段显示方式
2020/04/03 Python
突袭HTML5之Javascript API扩展5—其他扩展(应用缓存/服务端消息/桌面通知)
2013/01/31 HTML / CSS
html5的自定义data-*属性与jquery的data()方法的使用
2014/07/02 HTML / CSS
美国最顶级的精品店之一:Hampden Clothing
2016/12/22 全球购物
北京麒麟网信息技术有限公司网络游戏测试面试题
2013/09/28 面试题
优秀团员个人的自我评价
2013/10/02 职场文书
生产部经理岗位职责
2013/12/16 职场文书
班级德育工作实施方案
2014/02/21 职场文书
实习护士自荐信
2014/06/21 职场文书
安全环保演讲稿
2014/08/28 职场文书
2014年终工作总结范本
2014/12/15 职场文书
大学教师个人总结
2015/02/10 职场文书
社区六一儿童节活动总结
2015/02/11 职场文书
2015年化验室工作总结
2015/04/23 职场文书
2015年语文教师工作总结
2015/05/25 职场文书
合理缓解职场压力,让你随时保持最佳状态!
2019/06/21 职场文书
前端canvas中物体边框和控制点的实现示例
2022/08/05 Javascript