TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 切片和range()用法说明
Mar 24 Python
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
python分析nignx访问日志脚本分享
Feb 26 Python
浅谈python字符串方法的简单使用
Jul 18 Python
对Python之gzip文件读写的方法详解
Feb 08 Python
Django给admin添加Action的步骤详解
May 01 Python
Python创建或生成列表的操作方法
Jun 19 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
PyQt5实现登录页面
May 30 Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 Python
通过Python实现Payload分离免杀过程详解
Jul 13 Python
Python通过m3u8文件下载合并ts视频的操作
Apr 16 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
工厂模式在Zend Framework中应用介绍
2012/07/10 PHP
PHP模板引擎Smarty中的保留变量用法分析
2016/04/11 PHP
thinkPHP5框架路由常用知识点汇总
2019/09/15 PHP
jQuery中校验时间格式的正则表达式小结
2013/09/22 Javascript
javascript根据时间生成m位随机数最大13位
2014/10/30 Javascript
原生JS实现LOADING效果
2015/03/16 Javascript
微信小程序 Image API实例详解
2016/09/30 Javascript
微信小程序实现带刻度尺滑块功能
2017/03/29 Javascript
使用 Node.js 实现图片的动态裁切及算法实例代码详解
2018/09/29 Javascript
JS/HTML5游戏常用算法之碰撞检测 地图格子算法实例详解
2018/12/12 Javascript
JavaScript错误处理操作实例详解
2019/01/04 Javascript
layui实现把数据表格时间戳转换为时间格式的例子
2019/09/12 Javascript
jquery ajax 请求小技巧实例分析
2019/11/11 jQuery
基于JavaScript实现贪吃蛇游戏
2020/03/16 Javascript
记一次用ts+vuecli4重构项目的实现
2020/05/21 Javascript
js实现简单的无缝轮播效果
2020/09/05 Javascript
[45:56]Ti4正赛第一天 VG vs NEWBEE 3
2014/07/19 DOTA
python编程培训 python培训靠谱吗
2018/01/17 Python
python3第三方爬虫库BeautifulSoup4安装教程
2018/06/19 Python
python url 参数修改方法
2018/12/26 Python
Django项目后台不挂断运行的方法
2019/08/31 Python
python GUI库图形界面开发之PyQt5表单布局控件QFormLayout详细使用方法与实例
2020/03/06 Python
python反爬虫方法的优缺点分析
2020/11/25 Python
分享8款纯CSS3实现的搜索框功能
2017/09/14 HTML / CSS
中东地区最大的奢侈品市场:The Luxury Closet
2019/04/09 全球购物
学习十八大报告感言
2014/02/04 职场文书
市场营销毕业生自荐信范文
2014/04/01 职场文书
2015年元旦活动总结
2014/05/09 职场文书
村级换届选举方案
2014/05/10 职场文书
学校清明节活动总结
2014/07/04 职场文书
机关作风建设工作总结
2014/10/23 职场文书
企业整改报告范文
2014/11/08 职场文书
运动会闭幕式致辞
2015/07/29 职场文书
Python使用scapy模块发包收包
2021/05/07 Python
使用pandas或numpy处理数据中的空值(np.isnan()/pd.isnull())
2021/05/14 Python
一起来看看Vue的核心原理剖析
2022/03/24 Vue.js