TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中元祖、列表和字典的区别
Aug 17 Python
python WindowsError的错误代码详解
Jul 23 Python
Python模块文件结构代码详解
Feb 03 Python
Python实现求一个集合所有子集的示例
May 04 Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 Python
wxpython实现按钮切换界面的方法
Nov 19 Python
Python 实现opencv所使用的图片格式与 base64 转换
Jan 09 Python
Python版中国省市经纬度
Feb 11 Python
计算Python Numpy向量之间的欧氏距离实例
May 22 Python
将pycharm配置为matlab或者spyder的用法说明
Jun 08 Python
Python使用内置函数setattr设置对象的属性值
Oct 16 Python
Numpy中的数组搜索中np.where方法详细介绍
Jan 08 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
php中长文章分页显示实现代码
2012/09/29 PHP
浅析PHP中call user func()函数及如何使用call user func调用自定义函数
2015/11/05 PHP
PHP实现基本留言板功能原理与步骤详解
2020/03/26 PHP
JQuery 构建客户/服务分离的链接模型中Table中的排序分析
2010/01/22 Javascript
JavaScript中Cookie操作实例
2015/01/09 Javascript
jQuery数据缓存用法分析
2015/02/20 Javascript
JS限制文本框只能输入数字和字母方法
2015/02/28 Javascript
基于jquery实现放大镜效果
2015/08/17 Javascript
基于MVC4+EasyUI的Web开发框架形成之旅之界面控件的使用
2015/12/16 Javascript
Angularjs中的ui-bootstrap的使用教程
2017/02/19 Javascript
Angular父组件调用子组件的方法
2018/04/02 Javascript
vue.draggable实现表格拖拽排序效果
2018/12/01 Javascript
微信小程序使用setData修改数组中单个对象的方法分析
2018/12/30 Javascript
写给新手同学的vuex快速上手指北小结
2020/04/14 Javascript
[03:28]2014DOTA2国际邀请赛 走近EG战队天才中单Arteezy
2014/07/12 DOTA
跟老齐学Python之玩转字符串(2)
2014/09/14 Python
python读取excel表格生成erlang数据
2017/08/26 Python
Python虚拟环境项目实例
2017/11/20 Python
Python中enumerate()函数编写更Pythonic的循环
2018/03/06 Python
Flask框架学习笔记之模板操作实例详解
2019/08/15 Python
python 工具 字符串转numpy浮点数组的实现
2020/03/14 Python
Python使用文件操作实现一个XX信息管理系统的示例
2020/07/02 Python
英国最大最好的无人机商店:Drones Direct
2019/07/12 全球购物
美国价格实惠的在线眼镜网站:Zeelool
2020/12/25 全球购物
C#软件工程师英语面试题
2015/06/07 面试题
门诊挂号室室长岗位职责
2013/11/27 职场文书
运动会稿件300字
2014/02/14 职场文书
产品开发计划书
2014/04/27 职场文书
个人授权委托书范文
2014/09/21 职场文书
公司租车协议书
2015/01/29 职场文书
医生辞职信范文
2015/03/02 职场文书
2015年食堂工作总结报告
2015/04/23 职场文书
公司备用金管理制度
2015/08/04 职场文书
2016年“六一儿童节”校园广播稿
2015/12/17 职场文书
html form表单基础入门案例讲解
2021/07/15 HTML / CSS
GoFrame基于性能测试得知grpool使用场景
2022/06/21 Golang