TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python fabric使用笔记
May 09 Python
Python中asyncore异步模块的用法及实现httpclient的实例
Jun 28 Python
python 安装virtualenv和virtualenvwrapper的方法
Jan 13 Python
Python解析命令行读取参数--argparse模块使用方法
Jan 23 Python
mac安装pytorch及系统的numpy更新方法
Jul 26 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
详解Python中is和==的区别
Mar 21 Python
Python创建一个元素都为0的列表实例
Nov 28 Python
Spring Cloud Feign高级应用实例详解
Dec 10 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 Python
python如何实现图片压缩
Sep 11 Python
python regex库实例用法总结
Jan 03 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
thinkphp项目部署到Linux服务器上报错“模板不存在”如何解决
2016/04/27 PHP
php读取和保存base64编码的图片内容
2017/04/22 PHP
php字符串截取函数mb_substr用法实例分析
2019/06/25 PHP
PHP安全之register_globals的on和off的区别
2020/07/23 PHP
JavaScript 滚轮事件使用说明
2010/03/07 Javascript
扩展jquery实现客户端表格的分页、排序功能代码
2011/03/16 Javascript
JavaScript入门之基本函数详解
2011/10/21 Javascript
html超链接打开窗口大小的方法
2013/03/05 Javascript
详解JavaScript的表达式与运算符
2015/11/30 Javascript
jQuery插件编写步骤详解
2016/06/03 Javascript
jQuery ajax中使用confirm,确认是否删除的简单实例
2016/06/17 Javascript
AngularJS实现在ng-Options加上index的解决方法
2016/11/03 Javascript
简单谈谈Javascript函数中的arguments
2017/02/09 Javascript
Nodejs 获取时间加手机标识的32位标识实现代码
2017/03/07 NodeJs
vue实现一个移动端屏蔽滑动的遮罩层实例
2017/06/08 Javascript
Vue子组件向父组件通信与父组件调用子组件中的方法
2018/06/22 Javascript
Element Table的row-class-name无效与动态高亮显示选中行背景色
2018/11/30 Javascript
基于element-ui组件手动实现单选和上传功能
2018/12/06 Javascript
ES6 class的应用实例分析
2019/06/27 Javascript
微信小程序tabBar 返回tabBar不刷新页面
2019/07/25 Javascript
Vue的Eslint配置文件eslintrc.js说明与规则介绍
2020/02/03 Javascript
Vue实现手机计算器
2020/08/17 Javascript
详解实现vue的数据响应式原理
2021/01/20 Vue.js
Python中的引用和拷贝浅析
2014/11/22 Python
Python基于PycURL实现POST的方法
2015/07/25 Python
简单谈谈python中的Queue与多进程
2016/08/25 Python
python使用OpenCV模块实现图像的融合示例代码
2020/04/10 Python
详解pyinstaller生成exe的闪退问题解决方案
2020/06/19 Python
通过实例简单了解python yield使用方法
2020/08/06 Python
HTML5中语义化 b 和 i 标签
2008/10/17 HTML / CSS
什么是用户模式(User Mode)与内核模式(Kernel Mode) ?
2014/07/21 面试题
个人简历自我评价
2014/01/06 职场文书
书法比赛获奖感言
2014/02/10 职场文书
汽车销售经理岗位职责
2014/06/09 职场文书
泸县召开党的群众路线教育实践活动总结大会新闻稿
2014/10/21 职场文书
Python 数据可视化之Matplotlib详解
2021/11/02 Python