TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python去掉字符串中重复字符的方法
Feb 27 Python
Python中使用hashlib模块处理算法的教程
Apr 28 Python
Python for Informatics 第11章之正则表达式(四)
Apr 21 Python
python实现SMTP邮件发送功能
Jun 16 Python
python目录与文件名操作例子
Aug 28 Python
Python快速排序算法实例分析
Nov 29 Python
python tensorflow基于cnn实现手写数字识别
Jan 01 Python
PyQt5的安装配置过程,将ui文件转为py文件后显示窗口的实例
Jun 19 Python
浅析PEP570新语法: 只接受位置参数
Oct 15 Python
详解pyqt5的UI中嵌入matplotlib图形并实时刷新(挖坑和填坑)
Aug 07 Python
Opencv常见图像格式Data Type及代码实例
Nov 02 Python
分析Python感知线程状态的解决方案之Event与信号量
Jun 16 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
在Windows中安装Apache2和PHP4的权威指南
2006/10/09 PHP
php实现粘贴截图并完成上传功能
2015/05/17 PHP
一个完整的PHP类包含的七种语法说明
2015/06/04 PHP
PHP面向对象之领域模型+数据映射器实例(分析)
2017/06/21 PHP
jQuery 使用手册(一)
2009/09/23 Javascript
JS继承 笔记
2011/07/13 Javascript
firefox下jQuery UI Autocomplete 1.8.*中文输入修正方法
2012/09/19 Javascript
得到jQuery detach()后节点中的某个值实现代码
2013/02/05 Javascript
浅析hasOwnProperty方法的应用
2013/11/20 Javascript
jQuery对象初始化的传参方式
2015/02/26 Javascript
Function.prototype.apply()与Function.prototype.call()小结
2016/04/27 Javascript
jQuery自定义多选下拉框效果
2017/06/19 jQuery
基于JS实现移动端左滑删除功能
2017/07/28 Javascript
微信小程序实现topBar底部选择栏效果
2018/07/20 Javascript
默认浏览器设置及vue自动打开页面的方法
2018/09/21 Javascript
对 Vue-Router 进行单元测试的方法
2018/11/05 Javascript
如何正确理解vue中的key详解
2019/11/02 Javascript
js实现随机点名
2021/01/19 Javascript
Python和JavaScript间代码转换的4个工具
2016/02/22 Python
Python编程实现及时获取新邮件的方法示例
2017/08/10 Python
Python探索之爬取电商售卖信息代码示例
2017/10/27 Python
python编程实现12306的一个小爬虫实例
2017/12/27 Python
python读取文本绘制动态速度曲线
2018/06/21 Python
pycharm创建scrapy项目教程及遇到的坑解析
2019/08/15 Python
python创建学生成绩管理系统
2019/11/22 Python
css3+jq创作含苞待放的荷花
2014/02/20 HTML / CSS
澳大利亚儿童鞋在线:The Trybe
2019/07/16 全球购物
大学生就业自我推荐信
2014/05/10 职场文书
市场营销专业毕业生求职信
2014/07/21 职场文书
我的中国梦演讲稿300字
2014/08/19 职场文书
乡镇创先争优活动总结
2014/08/28 职场文书
不同意离婚答辩状
2015/05/22 职场文书
大学生社会实践感想
2015/08/11 职场文书
vue实现同时设置多个倒计时
2021/05/20 Vue.js
redis不能访问本机真实ip地址的解决方案
2021/07/07 Redis
Canvas如何做个雪花屏版404的实现
2021/09/25 HTML / CSS