TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python select.select模块通信的实例讲解
Sep 21 Python
基于Django框架利用Ajax实现点赞功能实例代码
Aug 19 Python
实例讲解Python3中abs()函数
Feb 19 Python
Python中字符串与编码示例代码
May 20 Python
Python学习笔记之列表和成员运算符及列表相关方法详解
Aug 22 Python
使用Pytorch来拟合函数方式
Jan 14 Python
python GUI库图形界面开发之PyQt5信号与槽基础使用方法与实例
Mar 06 Python
opencv python 图片读取与显示图片窗口未响应问题的解决
Apr 24 Python
Python 如何操作 SQLite 数据库
Aug 17 Python
Python Map 函数的使用
Aug 28 Python
Python调用Redis的示例代码
Nov 24 Python
Python可变集合和不可变集合的构造方法大全
Dec 06 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
在php中使用sockets:从新闻组中获取文章
2006/10/09 PHP
ThinkPHP表单自动提交验证实例教程
2014/07/18 PHP
PHP+MySQL统计该库中每个表的记录数并按递减顺序排列的方法
2016/02/15 PHP
thinkPHP5.0框架引入Traits功能实例分析
2017/03/18 PHP
PHP经典实用正则表达式小结
2017/05/04 PHP
Laravel中为什么不使用blpop取队列详析
2018/08/01 PHP
PHP使用gearman进行异步的邮件或短信发送操作详解
2020/02/27 PHP
列表内容的选择
2006/06/30 Javascript
JavaScript中:表达式和语句的区别[译]
2012/09/17 Javascript
JS实现自适应高度表单文本框的方法
2015/02/25 Javascript
js实现的牛顿摆效果
2015/03/31 Javascript
很全面的JavaScript常用功能汇总集合
2016/01/22 Javascript
详解JavaScript的AngularJS框架中的作用域与数据绑定
2016/03/04 Javascript
jquery对dom节点的操作【推荐】
2016/04/15 Javascript
vue-router单页面路由
2017/06/17 Javascript
基于VUE.JS的移动端框架Mint UI的使用
2017/10/11 Javascript
webpack4 SCSS提取和懒加载的示例
2018/09/03 Javascript
如何进行微信公众号开发的本地调试的方法
2019/06/16 Javascript
原生js实现简单轮播图
2020/10/26 Javascript
微信小程序实现弹幕墙(祝福墙)
2020/11/18 Javascript
vue实现登录功能
2020/12/31 Vue.js
[28:48]《真视界》- 2017年国际邀请赛
2017/09/27 DOTA
Python模块搜索概念介绍及模块安装方法介绍
2015/06/03 Python
浅谈python 线程池threadpool之实现
2017/11/17 Python
Python模块文件结构代码详解
2018/02/03 Python
Python图像滤波处理操作示例【基于ImageFilter类】
2019/01/03 Python
pandas进行时间数据的转换和计算时间差并提取年月日
2019/07/06 Python
Python类的绑定方法和非绑定方法实例解析
2020/03/04 Python
Python基于class()实现面向对象原理详解
2020/03/26 Python
Python如何定义有可选参数的元类
2020/07/31 Python
HTML5自定义属性前缀data-及dataset的使用方法(html5 新特性)
2017/08/24 HTML / CSS
世界上最大的各式箱包网络零售店:eBag
2016/07/21 全球购物
护理专科毕业自荐信范文
2014/04/21 职场文书
电台编导求职信
2014/05/06 职场文书
世界遗产导游词
2015/02/13 职场文书
使用 Apache Superset 可视化 ClickHouse 数据的两种方法
2021/07/07 Servers