将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的RSS阅读器实例
Jul 25 Python
python中异常捕获方法详解
Mar 03 Python
Python实现批量压缩图片
Jan 25 Python
Python操作Redis之设置key的过期时间实例代码
Jan 25 Python
PyQt5每天必学之工具提示功能
Apr 19 Python
django项目搭建与Session使用详解
Oct 10 Python
Python Pexpect库的简单使用方法
Jan 29 Python
如何在Python中实现goto语句的方法
May 18 Python
多版本python的pip 升级后, pip2 pip3 与python版本失配解决方法
Sep 11 Python
Python日期格式和字符串格式相互转换的方法
Feb 18 Python
Python 如何创建一个线程池
Jul 28 Python
python标准库ElementTree处理xml
May 20 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
谈谈PHP语法(4)
2006/10/09 PHP
轻松修复Discuz!数据库
2008/05/03 PHP
怎样去阅读一份php源代码
2009/08/21 PHP
30 个很棒的PHP开源CMS内容管理系统小结
2011/10/14 PHP
关于zend studio 出现乱码问题的总结
2013/06/23 PHP
PHP5中GD库生成图形验证码(有汉字)
2013/07/28 PHP
Nginx下配置codeigniter框架方法
2015/04/07 PHP
PHP防止刷新重复提交页面的示例代码
2015/11/11 PHP
PHP微信开发之微信录音临时转永久存储
2018/01/26 PHP
Jquery iframe内部出滚动条
2010/02/11 Javascript
浅析js中2个等号与3个等号的区别
2013/08/06 Javascript
在父页面调用子页面的JS方法
2013/09/29 Javascript
javascript的理解及经典案例分析
2016/05/20 Javascript
jquery ezUI 双击行记录弹窗查看明细的实现方法
2016/06/01 Javascript
VC调用javascript的几种方法(推荐)
2016/08/09 Javascript
JavaScript实现审核流程状态的动态显示进度条
2017/03/15 Javascript
JS简单获取当前日期时间的方法(如:2017-03-29 11:41:10 星期四)
2017/03/29 Javascript
利用JS hash制作单页Web应用的方法详解
2017/10/10 Javascript
原生JS实现自定义下拉单选选择框功能
2018/10/12 Javascript
javascript实现评分功能
2020/06/24 Javascript
微信小程序视频弹幕发送功能的实现
2020/12/28 Javascript
python paramiko实现ssh远程访问的方法
2013/12/03 Python
Python使用py2exe打包程序介绍
2014/11/20 Python
python3实现ftp服务功能(服务端 For Linux)
2017/03/24 Python
Python内建函数之raw_input()与input()代码解析
2017/10/26 Python
python 每天如何定时启动爬虫任务(实现方法分享)
2018/05/21 Python
python实现字符串加密 生成唯一固定长度字符串
2019/03/22 Python
python twilio模块实现发送手机短信功能
2019/08/02 Python
Django之路由层的实现
2019/09/09 Python
Python实现点云投影到平面显示
2020/01/18 Python
如何搭建pytorch环境的方法步骤
2020/05/06 Python
html5触摸事件判断滑动方向的实现
2018/06/05 HTML / CSS
腾讯技术类校园招聘笔试试题
2014/05/06 面试题
一家外企的面试题目(C/C++面试题,C语言面试题)
2014/03/24 面试题
酒店总经理岗位职责
2014/03/17 职场文书
市场营销毕业生自荐信范文
2014/04/01 职场文书