TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
开源软件包和环境管理系统Anaconda的安装使用
Sep 04 Python
python中装饰器级连的使用方法示例
Sep 29 Python
python3+PyQt5使用数据库窗口视图
Apr 24 Python
python版飞机大战代码分享
Nov 20 Python
python使用pipeline批量读写redis的方法
Feb 18 Python
scrapy-redis的安装部署步骤讲解
Feb 27 Python
Django框架模板文件使用及模板文件加载顺序分析
May 23 Python
python elasticsearch从创建索引到写入数据的全过程
Aug 04 Python
scrapy头部修改的方法详解
Dec 06 Python
用Python制作音乐海报
Jan 26 Python
Python页面加载的等待方式总结
Feb 28 Python
Python字符串格式化方式
Apr 07 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
php框架Phpbean说明
2008/01/10 PHP
php PDO中文乱码解决办法
2009/07/20 PHP
浅析PHP中json_encode与json_decode的区别
2020/07/15 PHP
一个无限级XML绑定跨框架菜单(For IE)
2007/01/27 Javascript
jQuery 中关于CSS操作部分使用说明
2007/06/10 Javascript
Json对象替换字符串占位符实现代码
2010/11/17 Javascript
扩展Jquery插件处理mouseover时内部有子元素时发生样式闪烁
2011/12/08 Javascript
javascript学习笔记(三) String 字符串类型介绍
2012/06/19 Javascript
表格单元格交错着色实现思路及代码
2013/04/01 Javascript
js实现动态添加、删除行、onkeyup表格求和示例
2013/08/18 Javascript
jquery批量设置属性readonly和disabled的方法
2014/01/24 Javascript
浅析JavaScript中的事件机制
2015/06/04 Javascript
JAVASCRIPT代码编写俄罗斯方块网页版
2015/11/26 Javascript
浅析AngularJS中的指令
2016/03/20 Javascript
JS 实现导航菜单中的二级下拉菜单的几种方式
2016/10/31 Javascript
JS 终止执行的实现方法
2016/11/24 Javascript
React项目动态设置title标题的方法示例
2018/09/26 Javascript
微信小程序websocket实现即时聊天功能
2019/05/21 Javascript
python使用生成器实现可迭代对象
2018/03/20 Python
pyttsx3实现中文文字转语音的方法
2018/12/24 Python
jenkins配置python脚本定时任务过程图解
2019/10/29 Python
Python面向对象程序设计之静态方法、类方法、属性方法原理与用法分析
2020/03/23 Python
Python sorted对list和dict排序
2020/06/09 Python
Python使用socket模块实现简单tcp通信
2020/08/18 Python
Java的接口和C++的虚类的相同和不同处
2014/03/27 面试题
一些Solaris面试题
2015/12/22 面试题
工商企业管理应届生求职信
2013/11/03 职场文书
工程资料员岗位职责
2014/03/10 职场文书
企业精神口号
2014/06/11 职场文书
五一口号
2014/06/19 职场文书
海洋科学专业求职信
2014/08/10 职场文书
个人工作总结范文2014
2014/11/07 职场文书
专家推荐信怎么写
2015/03/25 职场文书
小学运动会前导词
2015/07/20 职场文书
python实战之90行代码写个猜数字游戏
2021/04/22 Python
给numpy.array增加维度的超简单方法
2021/06/02 Python