TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中基础的socket编程实战攻略
Jun 01 Python
详解Django之auth模块(用户认证)
Apr 17 Python
python一键去抖音视频水印工具
Sep 14 Python
在Python中输入一个以空格为间隔的数组方法
Nov 13 Python
python实现五子棋小游戏
Mar 25 Python
使用pytorch完成kaggle猫狗图像识别方式
Jan 10 Python
django有外键关系的两张表如何相互查找
Feb 10 Python
使用Python FastAPI构建Web服务的实现
Jun 08 Python
Python使用windows设置定时执行脚本
Nov 12 Python
python 爬虫如何实现百度翻译
Nov 16 Python
如何利用python正则表达式匹配版本信息
Dec 09 Python
python接口测试返回数据为字典取值方式
Feb 12 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
php 图片上添加透明度渐变的效果
2009/06/29 PHP
PHP代码实现爬虫记录――超管用
2015/07/31 PHP
PHP模块化安装教程
2016/06/01 PHP
通过身份证号得到出生日期和性别的js代码
2009/11/23 Javascript
js中判断文本框是否为空的两种方法
2011/07/31 Javascript
JavaScript中常用的六种互动方法示例
2015/03/13 Javascript
jQuery对html元素的取值与赋值实例详解
2015/12/18 Javascript
JS实现的倒计时效果实例(2则实例)
2015/12/23 Javascript
详解js私有作用域中创建特权方法
2016/01/25 Javascript
Nodejs中的this详解
2016/03/26 NodeJs
解决淘宝cnpm 安装后cnpm不是内部或外部命令的问题
2018/05/17 Javascript
使用RN Animated做一个“添加购物车”动画的方法
2018/09/12 Javascript
vue3.0 CLI - 2.3 - 组件 home.vue 中学习指令和绑定
2018/09/14 Javascript
浅谈一种让小程序支持JSX语法的新思路
2019/06/16 Javascript
微信小程序实现3D轮播图效果(非swiper组件)
2019/09/21 Javascript
jquery轻量级数字动画插件countUp.js使用详解
2019/10/17 jQuery
JS函数本身的作用域实例分析
2020/03/16 Javascript
Python守护进程和脚本单例运行详解
2017/01/06 Python
python 爬虫一键爬取 淘宝天猫宝贝页面主图颜色图和详情图的教程
2018/05/22 Python
pytorch 模型可视化的例子
2019/08/17 Python
python颜色随机生成器的实例代码
2020/01/10 Python
tensorflow常用函数API介绍
2020/04/19 Python
python:删除离群值操作(每一行为一类数据)
2020/06/08 Python
python进度条显示之tqmd模块
2020/08/22 Python
Java Unsafe类实现原理及测试代码
2020/09/15 Python
Spartoo葡萄牙鞋类网站:线上销售鞋履与时尚配饰
2017/01/11 全球购物
企业安全生产责任书范本
2014/07/28 职场文书
家庭教育的心得体会
2014/09/01 职场文书
2014第二批党的群众路线教育实践活动对照检查材料思想汇报
2014/09/18 职场文书
群众路线教育实践活动民主生活会个人检查对照思想汇报
2014/10/04 职场文书
党的群众路线教育实践活动个人对照检查材料(四风)
2014/11/05 职场文书
三潭印月的导游词
2015/02/12 职场文书
《吃水不忘挖井人》教学反思
2016/02/22 职场文书
只用50行Python代码爬取网络美女高清图片
2021/06/02 Python
Spring Data JPA使用JPQL与原生SQL进行查询的操作
2021/06/15 Java/Android
python_tkinter事件类型详情
2022/03/20 Python