TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读取注册表中值的方法
Apr 08 Python
python使用scrapy解析js示例
Jan 23 Python
python实现在每个独立进程中运行一个函数的方法
Apr 23 Python
python图像常规操作
Nov 11 Python
python爬取各类文档方法归类汇总
Mar 22 Python
Python3 jupyter notebook 服务器搭建过程
Nov 30 Python
Python随机生成身份证号码及校验功能
Dec 04 Python
python基于json文件实现的gearman任务自动重启代码实例
Aug 13 Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 Python
python tornado修改log输出方式
Nov 18 Python
如何理解Python中的变量
Jun 01 Python
python pygame 开发五子棋双人对弈
May 02 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
PHP 防恶意刷新实现代码
2010/05/16 PHP
PHP警告Cannot use a scalar value as an array的解决方法
2012/01/11 PHP
JScript中的"this"关键字使用方式补充材料
2007/03/08 Javascript
ExtJS 下拉多选框lovcombo
2010/05/19 Javascript
当某个文本框成为焦点时即清除文本框内容
2014/04/28 Javascript
jQuery实现的一个自定义Placeholder属性插件
2014/08/11 Javascript
浅谈javascript中replace()方法
2015/11/10 Javascript
jQuery实现三级菜单的代码
2016/05/09 Javascript
jQuery插件zTree实现的多选树效果示例
2017/03/08 Javascript
ES6(ECMAScript 6)新特性之模板字符串用法分析
2017/04/01 Javascript
JS编写兼容IE6,7,8浏览器无缝自动轮播
2018/10/12 Javascript
javascript实现简易聊天室
2019/07/12 Javascript
微信小程序 腾讯地图SDK 获取当前地址实现解析
2019/08/12 Javascript
JavaScript 事件代理需要注意的地方
2020/09/08 Javascript
[03:42]2016国际邀请赛中国区预选赛首日现场玩家采访
2016/06/26 DOTA
python中xrange用法分析
2015/04/15 Python
python数字图像处理之高级滤波代码详解
2017/11/23 Python
Python引用计数操作示例
2018/08/23 Python
Python使用graphviz画流程图过程解析
2020/03/31 Python
python进度条显示-tqmd模块的实现示例
2020/08/23 Python
python解决OpenCV在读取显示图片的时候闪退的问题
2021/02/23 Python
一款css实现的鼠标经过按钮的特效
2014/09/11 HTML / CSS
CSS3属性background-size使用指南
2014/12/09 HTML / CSS
CSS Houdini实现动态波浪纹效果
2019/07/30 HTML / CSS
html5将图片转换成base64的实例代码
2016/09/21 HTML / CSS
Peter Alexander新西兰站:澳大利亚领先的睡衣设计师品牌
2016/12/10 全球购物
JD Sports瑞典:英国领先的运动时尚商店
2018/01/28 全球购物
HTC VIVE美国官网:VR虚拟现实眼镜
2018/02/13 全球购物
联想韩国官网:Lenovo Korea
2018/05/10 全球购物
数学与统计学院学生个人职业生涯规划书
2014/02/10 职场文书
孝敬父母的演讲稿
2014/05/14 职场文书
党性分析自查总结
2014/10/14 职场文书
2014年技术员工作总结
2014/11/18 职场文书
运动会入场词
2015/07/18 职场文书
python实战之90行代码写个猜数字游戏
2021/04/22 Python
nginx代理实现静态资源访问的示例代码
2022/07/07 Servers