TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据库操作常用功能使用详解(创建表/插入数据/获取数据)
Dec 06 Python
Linux下使用python调用top命令获得CPU利用率
Mar 10 Python
使用PyCharm配合部署Python的Django框架的配置纪实
Nov 19 Python
简单谈谈Python中的闭包
Nov 30 Python
Django 限制用户访问频率的中间件的实现
Aug 23 Python
Python简单处理坐标排序问题示例
Jul 11 Python
python中enumerate() 与zip()函数的使用比较实例分析
Sep 03 Python
python实现统计代码行数的小工具
Sep 19 Python
python实现回旋矩阵方式(旋转矩阵)
Dec 04 Python
解决python pandas读取excel中多个不同sheet表格存在的问题
Jul 14 Python
tensorflow与numpy的版本兼容性问题的解决
Jan 08 Python
jupyter notebook远程访问不了的问题解决方法
Jan 11 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
一个简易需要注册的留言版程序
2006/10/09 PHP
PHP数组实例总结与说明
2011/08/23 PHP
PHP开发注意事项总结
2015/02/04 PHP
PHP图像处理类库及演示分享
2015/05/17 PHP
PHP5.3连接Oracle客户端及PDO_OCI模块的安装方法
2016/05/13 PHP
JavaScript CSS修改学习第一章 查找位置
2010/02/19 Javascript
某页码显示的helper 少量调整,另附js版
2010/09/12 Javascript
IFrame跨域高度自适应实现代码
2012/08/16 Javascript
appendChild() 或 insertBefore()使用与区别介绍
2013/10/11 Javascript
js字符串转换成数字与数字转换成字符串的实现方法
2014/01/08 Javascript
jquery判断小数点两位和自动删除小数两位后的数字
2014/03/19 Javascript
简单掌握JavaScript中const声明常量与变量的用法
2016/05/21 Javascript
Angular设置title信息解决SEO方面存在问题
2016/08/19 Javascript
js注入 黑客之路必备!
2016/09/14 Javascript
weUI应用之JS常用信息提示弹层的封装
2016/11/21 Javascript
JS实现全屏预览F11功能的示例代码
2018/07/23 Javascript
Vue中的vue-resource示例详解
2018/11/02 Javascript
JavaScript禁用右键单击优缺点分析
2019/01/20 Javascript
Python中实现结构相似的函数调用方法
2015/03/10 Python
python从入门到精通(DAY 3)
2015/12/20 Python
Python编程之event对象的用法实例分析
2017/03/23 Python
如何利用python查找电脑文件
2018/04/27 Python
python实现在函数中修改变量值的方法
2019/07/16 Python
python3 反射的四种基本方法解析
2019/08/26 Python
在pycharm中为项目导入anacodna环境的操作方法
2020/02/12 Python
在django项目中导出数据到excel文件并实现下载的功能
2020/03/13 Python
pycharm下pyqt4安装及环境配置的教程
2020/04/24 Python
python能否java成为主流语言吗
2020/06/22 Python
python的flask框架难学吗
2020/07/31 Python
HTML5之WebGL 3D概述(下)—借助类库开发及框架介绍
2013/01/31 HTML / CSS
阿迪达斯印度官方商城:adidas India
2017/03/26 全球购物
Bench加拿大官方网站:英国城市服装品牌
2017/11/03 全球购物
美国领先的在线邮轮旅游公司:CruiseDirect
2018/06/07 全球购物
雷锋精神演讲稿
2014/05/13 职场文书
导游词之京东大峡谷旅游区
2019/10/29 职场文书
python flappy bird小游戏分步实现流程
2022/02/15 Python