TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 文件重命名工具代码
Jul 26 Python
Python随手笔记之标准类型内建函数
Dec 02 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
Python编程生成随机用户名及密码的方法示例
May 05 Python
Python遍历numpy数组的实例
Apr 04 Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 Python
python 给DataFrame增加index行名和columns列名的实现方法
Jun 08 Python
Python3.5 处理文本txt,删除不需要的行方法
Dec 10 Python
通过PHP与Python代码对比的语法差异详解
Jul 10 Python
Python FFT合成波形的实例
Dec 04 Python
python3利用Axes3D库画3D模型图
Mar 25 Python
Python基础之数据结构详解
Apr 28 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
让PHP支持页面回退的两种方法
2008/01/10 PHP
PHP的范围解析操作符(::)的含义分析说明
2011/07/03 PHP
CI框架无限级分类+递归的实现代码
2016/11/01 PHP
PHP使用XMLWriter读写xml文件操作详解
2018/07/31 PHP
PHP中有关长整数的一些操作教程
2019/09/11 PHP
JavaScript 应用技巧集合[推荐]
2009/08/30 Javascript
ECMAScript 创建自己的js类库
2012/11/22 Javascript
给Flash加一个超链接(推荐使用透明层)兼容主流浏览器
2013/06/09 Javascript
js判断横竖屏及禁止浏览器滑动条示例
2014/04/29 Javascript
jQuery选择器源码解读(二):select方法
2015/03/31 Javascript
javascript文本模板用法实例
2015/07/31 Javascript
jquery实现九宫格大转盘抽奖
2015/11/13 Javascript
轻松实现Bootstrap图片轮播
2020/04/20 Javascript
基于RequireJS和JQuery的模块化编程——常见问题全面解析
2016/04/14 Javascript
js实现添加可信站点、修改activex安全设置,禁用弹出窗口阻止程序
2016/08/17 Javascript
js实现悬浮窗效果(支持拖动)
2017/03/09 Javascript
vue axios数据请求及vue中使用axios的方法
2018/09/10 Javascript
跟老齐学Python之类的细节
2014/10/13 Python
常见的python正则用法实例讲解
2016/06/21 Python
Python 3实战爬虫之爬取京东图书的图片详解
2017/10/09 Python
Python实现自定义函数的5种常见形式分析
2018/06/16 Python
python 利用浏览器 Cookie 模拟登录的用户访问知乎的方法
2019/07/11 Python
如何基于python生成list的所有的子集
2019/11/11 Python
Python箱型图处理离群点的例子
2019/12/09 Python
Pycharm debug调试时带参数过程解析
2020/02/03 Python
python如何对链表操作
2020/10/10 Python
Django如何实现防止XSS攻击
2020/10/13 Python
html5使用canvas压缩图片的示例代码
2018/09/11 HTML / CSS
世界上最大的售后摩托车零配件超市:J&P Cycles
2017/12/08 全球购物
英国第一的滑雪服装和装备零售商:Snow+Rock
2020/02/01 全球购物
新东方旗下远程教育网站:新东方在线
2020/03/19 全球购物
国贸专业自荐信范文
2014/03/02 职场文书
幼儿园六一亲子活动方案
2014/08/26 职场文书
机械制造专业大学生自我鉴定
2014/09/19 职场文书
酒店员工辞职信范文
2015/02/28 职场文书
我在伊朗长大观后感
2015/06/16 职场文书