TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python闭包实现计数器的方法
May 05 Python
Django中的“惰性翻译”方法的相关使用
Jul 27 Python
python 转换 Javascript %u 字符串为python unicode的代码
Sep 06 Python
Python读取excel中的图片完美解决方法
Jul 27 Python
Python并行分布式框架Celery详解
Oct 15 Python
Python发送邮件功能示例【使用QQ邮箱】
Dec 04 Python
python binascii 进制转换实例
Jun 12 Python
python3 自动识别usb连接状态,即对usb重连的判断方法
Jul 03 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
Python3使用PySynth制作音乐的方法
Sep 09 Python
Python实现Excel文件的合并(以新冠疫情数据为例)
Mar 20 Python
Anaconda安装pytorch和paddle的方法步骤
Apr 03 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
PHP下MAIL的另一解决方案
2006/10/09 PHP
PHP调用存储过程返回值不一致问题的解决方法分析
2016/04/26 PHP
thinkPHP删除前弹出确认框的简单实现方法
2016/05/16 PHP
php设计模式之观察者模式实例详解【星际争霸游戏案例】
2020/03/30 PHP
用javascript动态调整iframe高度的代码
2007/04/10 Javascript
抽出www.templatemonster.com的鼠标悬停加载大图模板的代码
2007/07/11 Javascript
6个DIV 135或246间隔一秒轮番显示效果
2010/07/24 Javascript
jquery动态添加option示例
2013/12/30 Javascript
javascript 动态创建表格的2种方法总结
2015/03/04 Javascript
微信小程序 免费SSL证书https、TLS版本问题的解决办法
2016/12/14 Javascript
jQuery插件zTree实现获取一级节点数据的方法
2017/03/08 Javascript
vue 全选与反选的实现方法(无Bug 新手看过来)
2018/02/09 Javascript
JavaScript简单实现关键字文本搜索高亮显示功能示例
2018/07/25 Javascript
JS高级技巧(简洁版)
2018/07/29 Javascript
layui的表单验证支持ajax判断用户名是否重复的实例
2019/09/06 Javascript
javascript 原型与原型链的理解及实例分析
2019/11/23 Javascript
JavaScript组合设计模式--改进引入案例分析
2020/05/23 Javascript
python中getattr函数使用方法 getattr实现工厂模式
2014/01/20 Python
Python中__init__.py文件的作用详解
2016/09/18 Python
基于Python代码编辑器的选用(详解)
2017/09/13 Python
Python2与python3中 for 循环语句基础与实例分析
2017/11/20 Python
Python实现决策树C4.5算法的示例
2018/05/30 Python
python绘制已知点的坐标的直线实例
2019/07/04 Python
flask应用部署到服务器的方法
2019/07/12 Python
python各类经纬度转换的实例代码
2019/08/08 Python
pandas factorize实现将字符串特征转化为数字特征
2019/12/19 Python
Python基于类路径字符串获取静态属性
2020/03/12 Python
使用TensorBoard进行超参数优化的实现
2020/07/06 Python
详解Python流程控制语句
2020/10/28 Python
HTML5仿手机微信聊天界面
2016/03/18 HTML / CSS
团工委书记自荐书范文
2013/12/17 职场文书
计算机专业毕业生求职信分享
2013/12/24 职场文书
狼和鹿教学反思
2014/02/05 职场文书
公休请假条
2014/04/11 职场文书
导师评语大全
2014/04/26 职场文书
战友聚会策划方案
2014/06/13 职场文书