将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块和pyquery实现阿里巴巴排名查询
Jan 16 Python
python实现的二叉树算法和kmp算法实例
Apr 25 Python
python中类的一些方法分析
Sep 25 Python
编写简单的Python程序来判断文本的语种
Apr 07 Python
使用IPython来操作Docker容器的入门指引
Apr 08 Python
Python下载指定页面上图片的方法
May 12 Python
Python内存管理方式和垃圾回收算法解析
Nov 11 Python
Python线性方程组求解运算示例
Jan 17 Python
Django框架设置cookies与获取cookies操作详解
May 27 Python
Python3开发实例之非关系型图数据库Neo4j安装方法及Python3连接操作Neo4j方法实例
Mar 18 Python
python爬虫用scrapy获取影片的实例分析
Nov 23 Python
python re模块常见用法例举
Mar 01 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
php 伪造本地文件包含漏洞的代码
2011/11/03 PHP
详解WordPress中用于合成数组的wp_parse_args()函数
2015/12/18 PHP
PHP正则匹配日期和时间(时间戳转换)的实例代码
2016/12/14 PHP
深入理解PHP的远程多会话调试
2017/09/21 PHP
PHP模型Model类封装数据库操作示例
2019/03/14 PHP
基于jquery的跨域调用文件
2010/11/19 Javascript
使用Jquery打造最佳用户体验的登录页面的实现代码
2011/07/08 Javascript
Eval and new funciton not the same thing
2012/12/27 Javascript
JavaScript操纵窗口的方法小结
2013/06/28 Javascript
解析javascript中鼠标滚轮事件
2015/05/26 Javascript
JS实现窗口加载时模拟鼠标移动的方法
2015/06/03 Javascript
javascript创建动态表单的方法
2015/07/25 Javascript
详解Bootstrap创建表单的三种格式(一)
2016/01/04 Javascript
AngularJS模块学习之Anchor Scroll
2016/01/19 Javascript
Easyui Treegrid改变默认图标的方法
2016/04/29 Javascript
AngularJS基础 ng-init 指令简单示例
2016/08/02 Javascript
js实现百度地图定位于地址逆解析,显示自己当前的地理位置
2016/12/08 Javascript
javascript 注释代码的几种方法总结
2017/01/04 Javascript
JavaScript实现大图轮播效果
2017/01/11 Javascript
vue+jquery+lodash实现滑动时顶部悬浮固定效果
2018/04/28 jQuery
详解Angular中通过$location获取地址栏的参数
2018/08/02 Javascript
微信小程序的tab选项卡的实现效果
2019/05/15 Javascript
微信小程序绘制图片发送朋友圈
2019/07/25 Javascript
解读Django框架中的低层次缓存API
2015/07/24 Python
Python基于PyGraphics包实现图片截取功能的方法
2017/12/21 Python
详解Python自建logging模块
2018/01/29 Python
浅谈PyQt5 的帮助文档查找方法,可以查看每个类的方法
2019/06/25 Python
python打印直角三角形与等腰三角形实例代码
2019/10/20 Python
如何利用pygame实现简单的五子棋游戏
2019/12/29 Python
在PyCharm中实现添加快捷模块
2020/02/12 Python
Python matplotlib可视化实例解析
2020/06/01 Python
Python偏函数实现原理及应用
2020/11/20 Python
主题教育活动总结
2014/05/05 职场文书
教师师德承诺书2016
2016/03/25 职场文书
Mysql数据库按时间点恢复实战记录
2021/06/30 MySQL
浅谈css实现背景颜色半透明的两种方法
2021/12/06 HTML / CSS