将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中实现一个高性能计数器(Counter)实例
Jul 09 Python
Python中获取网页状态码的两个方法
Nov 03 Python
在Django的URLconf中进行函数导入的方法
Jul 18 Python
深入解答关于Python的11道基本面试题
Apr 01 Python
解决Django migrate No changes detected 不能创建表的问题
May 27 Python
解决Python安装时报缺少DLL问题【两种解决方法】
Jul 15 Python
使用python实现对元素的长截图功能
Nov 14 Python
Python中sys模块功能与用法实例详解
Feb 26 Python
Python基于codecs模块实现文件读写案例解析
May 11 Python
Python selenium使用autoIT上传附件过程详解
May 26 Python
详解Python函数print用法
Jun 18 Python
使用scrapy实现增量式爬取方式
Jun 21 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
重料打造自己的“宝马”---第三代
2021/03/02 无线电
PHP实现视频文件上传完整实例
2014/08/28 PHP
php正则preg_replace_callback函数用法实例
2015/06/01 PHP
TP(thinkPHP)框架多层控制器和多级控制器的使用示例
2018/06/13 PHP
非常不错的功能强大代码简单的管理菜单美化版
2008/07/09 Javascript
用JS控制回车事件的代码
2011/02/20 Javascript
js Event对象的5种坐标
2011/09/12 Javascript
javascript动态加载实现方法一
2012/08/22 Javascript
CSS+jQuery实现的一个放大缩小动画效果
2013/09/24 Javascript
JavaScript中的操作符==与===介绍
2014/12/31 Javascript
javascript异步编程代码书写规范Promise学习笔记
2015/02/11 Javascript
jQuery模拟原生态App上拉刷新下拉加载更多页面及原理
2015/08/10 Javascript
Vue如何引入远程JS文件
2017/04/20 Javascript
three.js中3D视野的缩放实现代码
2017/11/16 Javascript
Vue隐藏显示、只读实例代码
2018/07/18 Javascript
Vue中android4.4不兼容问题的解决方法
2018/09/04 Javascript
微信小程序事件对象中e.target和e.currentTarget的区别详解
2019/05/08 Javascript
jQuery中使用validate插件校验表单功能
2019/05/24 jQuery
判断JavaScript中的两个变量是否相等的操作符
2019/12/21 Javascript
[33:15]2018DOTA2亚洲邀请赛3月30日 小组赛B组 VP VS Mineski
2018/03/31 DOTA
在Python web中实现验证码图片代码分享
2017/11/09 Python
Python访问MongoDB,并且转换成Dataframe的方法
2018/10/15 Python
python中partial()基础用法说明
2018/12/30 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
2019/06/13 Python
python实现根据文件格式分类
2019/10/31 Python
记录模型训练时loss值的变化情况
2020/06/16 Python
CSS3教程(5):网页背景图片
2009/04/02 HTML / CSS
美国批发供应商:Kole Imports
2019/04/10 全球购物
《维生素c的故事》教学反思
2014/02/18 职场文书
党员干部承诺书范文
2014/03/25 职场文书
成都人事代理协议书
2014/10/25 职场文书
2015年教研员工作总结
2015/05/26 职场文书
团干部培训班心得体会
2016/01/06 职场文书
用Python简陋模拟n阶魔方
2021/04/17 Python
纯html+css实现奥运五环的示例代码
2021/08/02 HTML / CSS
Java异常体系非正常停止和分类
2022/06/14 Java/Android