TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python支持断点续传的多线程下载示例
Jan 16 Python
关于Python面向对象编程的知识点总结
Feb 14 Python
Python多线程实现同步的四种方式
May 02 Python
python3库numpy数组属性的查看方法
Apr 17 Python
Python 通配符删除文件的实例
Apr 24 Python
在python2.7中用numpy.reshape 对图像进行切割的方法
Dec 05 Python
使用python判断jpeg图片的完整性实例
Jun 10 Python
python3文件复制、延迟文件复制任务的实现方法
Sep 02 Python
python爬虫添加请求头代码实例
Dec 28 Python
Python递归求出列表(包括列表中的子列表)的最大值实例
Feb 27 Python
使用Keras中的ImageDataGenerator进行批次读图方式
Jun 17 Python
python批量合成bilibili的m4s缓存文件为MP4格式 ver2.5
Dec 01 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
php缓冲输出实例分析
2015/01/05 PHP
PHP实现将几张照片拼接到一起的合成图片功能【便于整体打印输出】
2017/11/14 PHP
Laravel登录失败次数限制的实现方法
2020/08/26 PHP
javascript入门·动态的时钟,显示完整的一些方法,新年倒计时
2007/10/01 Javascript
通过正则格式化url查询字符串实现代码
2012/12/28 Javascript
JQuery操作Select的Options的Bug(IE8兼容性视图模式)
2013/04/21 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
jQuery中first()方法用法实例
2015/01/06 Javascript
javaScript生成支持中文带logo的二维码(jquery.qrcode.js)
2017/01/03 Javascript
Vue计算属性的学习笔记
2017/03/22 Javascript
微信小程序自定义组件封装及父子间组件传值的方法
2018/08/28 Javascript
利用js-cookie实现前端设置缓存数据定时失效
2019/06/18 Javascript
node.js制作一个简单的登录拦截器
2020/02/10 Javascript
解决vue单页面多个组件嵌套监听浏览器窗口变化问题
2020/07/30 Javascript
[03:00]DOTA2-DPC中国联赛1月18日Recap集锦
2021/03/11 DOTA
Python3读取zip文件信息的方法
2015/05/22 Python
python dataframe常见操作方法:实现取行、列、切片、统计特征值
2018/06/09 Python
Python 文本文件内容批量抽取实例
2018/12/10 Python
pandas读取csv文件,分隔符参数sep的实例
2018/12/12 Python
Python异常处理知识点总结
2019/02/18 Python
python+OpenCV实现车牌号码识别
2019/11/08 Python
使用python实现哈希表、字典、集合操作
2019/12/22 Python
解决pytorch报错:AssertionError: Invalid device id的问题
2020/01/10 Python
CSS3实现超慢速移动动画效果非常流畅无卡顿
2014/06/15 HTML / CSS
HTML5语义化元素你真的用对了吗
2019/08/22 HTML / CSS
国际贸易专业个人鉴定
2014/02/22 职场文书
汽车机修工岗位职责
2014/03/06 职场文书
民族学专业职业生涯规划范文:积跬步以至千里
2014/09/11 职场文书
意外死亡赔偿协议书
2014/10/14 职场文书
国际贸易实务实训报告
2014/11/05 职场文书
钱塘江大潮导游词
2015/02/03 职场文书
婚宴父母致辞
2015/07/27 职场文书
PHP中->和=>的意思
2021/03/31 PHP
POST提交数据常见的四种方式
2022/01/18 HTML / CSS
python 实现图片特效处理
2022/04/03 Python
Ruby序列化和持久化存储 Marshal和Pstore介绍
2022/04/18 Ruby