将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据库操作常用功能使用详解(创建表/插入数据/获取数据)
Dec 06 Python
在django中使用自定义标签实现分页功能
Jul 04 Python
python3 破解 geetest(极验)的滑块验证码功能
Feb 24 Python
使用Django和Python创建Json response的方法
Mar 26 Python
Python解析并读取PDF文件内容的方法
May 08 Python
Python3实现爬取简书首页文章标题和文章链接的方法【测试可用】
Dec 11 Python
python区块及区块链的开发详解
Jul 03 Python
matplotlib实现显示伪彩色图像及色度条
Dec 07 Python
通过python连接Linux命令行代码实例
Feb 18 Python
Keras 中Leaky ReLU等高级激活函数的用法
Jul 05 Python
基于python实现可视化生成二维码工具
Jul 08 Python
Python一些基本的图像操作和处理总结
Jun 23 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
BBS(php & mysql)完整版(六)
2006/10/09 PHP
基于OpenCart 开发支付宝,财付通,微信支付参数错误问题
2015/10/01 PHP
php实现查询功能(数据访问)
2017/05/23 PHP
PHP实现的只保留字符串首尾字符功能示例【隐藏部分字符串】
2019/03/11 PHP
海量经典的jQuery插件集合
2010/01/12 Javascript
jQuery .attr()和.removeAttr()方法操作元素属性示例
2013/07/16 Javascript
JavaScript中使用sencha gridpanel 编辑单元格、改变单元格颜色
2015/11/26 Javascript
深入理解JS继承和原型链的问题
2016/12/17 Javascript
React 组件渲染和更新的实现代码示例
2019/02/21 Javascript
jQuery实现input[type=file]多图预览上传删除等功能
2019/08/02 jQuery
VueX模块的具体使用(小白教程)
2020/06/05 Javascript
JavaScript语句错误throw、try及catch实例解析
2020/08/18 Javascript
Python机器学习之SVM支持向量机
2017/12/27 Python
解决python 输出是省略号的问题
2018/04/19 Python
Django 浅谈根据配置生成SQL语句的问题
2018/05/29 Python
python 循环读取txt文档 并转换成csv的方法
2018/10/26 Python
Ranorex通过Python将报告发送到邮箱的方法
2020/01/12 Python
python如何提升爬虫效率
2020/09/27 Python
python 实现超级玛丽游戏
2020/11/25 Python
pycharm Tab键设置成4个空格的操作
2021/02/26 Python
Nixon手表英国官网:美国尼克松手表品牌
2020/02/10 全球购物
Blank NYC官网:夹克、牛仔裤等
2020/12/16 全球购物
应用数学自荐书范文
2013/11/24 职场文书
事务机电主管工作职责
2014/02/25 职场文书
《赠汪伦》教学反思
2014/04/12 职场文书
创业融资计划书
2014/04/25 职场文书
幼儿园安全生产月活动总结
2014/07/05 职场文书
致百米运动员广播稿5篇
2014/10/13 职场文书
走进毛泽东观后感
2015/06/04 职场文书
学习焦裕禄观后感
2015/06/09 职场文书
入党转正申请书范文
2019/05/20 职场文书
MySQL如何解决幻读问题
2021/08/07 MySQL
mysql中整数数据类型tinyint详解
2021/12/06 MySQL
分享CSS盒子模型隐藏的几种方式
2022/02/28 HTML / CSS
springboot用户数据修改的详细实现
2022/04/06 Java/Android
JS中forEach()、map()、every()、some()和filter()的用法
2022/05/11 Javascript