TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从源码解析Python的Flask框架中request对象的用法
Jun 02 Python
Python网络编程使用select实现socket全双工异步通信功能示例
Apr 09 Python
python实现音乐下载器
Apr 15 Python
Python 通过requests实现腾讯新闻抓取爬虫的方法
Feb 22 Python
django query模块
Apr 20 Python
基于多进程中APScheduler重复运行的解决方法
Jul 22 Python
django重新生成数据库中的某张表方法
Aug 28 Python
Django日志及中间件模块应用案例
Sep 10 Python
Python3 用matplotlib绘制sigmoid函数的案例
Dec 11 Python
Python-typing: 类型标注与支持 Any类型详解
May 10 Python
解决pycharm安装scrapy DLL load failed:找不到指定的程序的问题
Jun 08 Python
4种方法python批量修改替换列表中元素
Apr 07 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
php 文件夹删除、php清除缓存程序
2009/08/25 PHP
PHP 将逗号、空格、回车分隔的字符串转换为数组的函数
2012/06/07 PHP
php获取参数的几种方法总结
2014/02/18 PHP
PHP实现把MySQL数据库导出为.sql文件实例(仿PHPMyadmin导出功能)
2014/05/10 PHP
ThinkPHP做文字水印时提示call an undefined function exif_imagetype()解决方法
2014/10/30 PHP
总结对比php中的多种序列化
2016/08/28 PHP
Zend Framework数据库操作技巧总结
2017/02/18 PHP
MAC下通过改apache配置文件切换php多版本的方法
2017/04/26 PHP
php装饰者模式简单应用案例分析
2019/10/23 PHP
使用CSS3实现字体颜色渐变的实现
2021/03/09 HTML / CSS
ajax中get和post的说明及使用与区别
2012/12/23 Javascript
addEventListener()第三个参数useCapture (Boolean)详细解析
2013/11/07 Javascript
javascript内存管理详细解析
2013/11/11 Javascript
扩展JS Date对象时间格式化功能的小例子
2013/12/02 Javascript
JavaScript中的数组操作介绍
2014/12/30 Javascript
jQuery post数据至ashx实例详解
2016/11/18 Javascript
JS得到当前时间的方法示例
2017/03/24 Javascript
微信小程序显示下拉列表功能【附源码下载】
2017/12/12 Javascript
详解mpvue scroll-view自动回弹bug解决方案
2018/10/01 Javascript
vue 内联样式style中的background用法说明
2020/08/05 Javascript
vue实现简单加法计算器
2020/10/22 Javascript
windows系统中python使用rar命令压缩多个文件夹示例
2014/05/06 Python
python中日志logging模块的性能及多进程详解
2017/07/18 Python
Python 加密与解密小结
2018/12/06 Python
python读取Kafka实例
2019/12/23 Python
Django choices下拉列表绑定实例
2020/03/13 Python
html5基础标签(html5视频标签 html5新标签用法)
2013/12/30 HTML / CSS
如何让Java程序执行效率更高
2014/06/25 面试题
一些网络技术方面的面试题
2014/05/01 面试题
给排水工程师岗位职责
2013/11/21 职场文书
开工庆典邀请函范文
2014/01/16 职场文书
保安的辞职报告怎么写
2014/01/20 职场文书
学习计划书怎么写
2014/09/15 职场文书
先进个人自荐书
2015/03/06 职场文书
分享15个Webpack实用的插件!!!
2021/03/31 Javascript
CSS+HTML 实现顶部导航栏功能
2021/08/30 HTML / CSS