将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序员鲜为人知但你应该知道的17个问题
Jun 04 Python
跟老齐学Python之变量和参数
Oct 10 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
Jan 20 Python
Python变量和数据类型详解
Feb 15 Python
Python Selenium Cookie 绕过验证码实现登录示例代码
Apr 10 Python
Python字典中的键映射多个值的方法(列表或者集合)
Oct 17 Python
在python中pandas读文件,有中文字符的方法
Dec 12 Python
python迭代器常见用法实例分析
Nov 22 Python
使用python3批量下载rbsp数据的示例代码
Dec 20 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 Python
pycharm最新激活码有效期至2100年(亲测可用)
Feb 05 Python
Django后端按照日期查询的方法教程
Feb 28 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
PHP数据类型的总结分析
2013/06/13 PHP
php获得刚插入数据的id 的几种方法总结
2018/05/31 PHP
Laravel框架实现的批量删除功能示例
2019/01/16 PHP
Laravel框架基于ajax和layer.js实现无刷新删除功能示例
2019/01/17 PHP
php设计模式之适配器模式原理、用法及注意事项详解
2019/09/24 PHP
解决JS浮点数运算出现Bug的方法
2013/03/12 Javascript
jquery控制左右箭头滚动图片列表的实例
2013/05/20 Javascript
EasyUI实现二级页面的内容勾选的方法
2015/03/01 Javascript
基于JavaScript怎么实现让歌词滚动播放
2015/11/03 Javascript
jQuery图片旋转插件jQueryRotate.js用法实例(附demo下载)
2016/01/21 Javascript
深入理解JavaScript中的call、apply、bind方法的区别
2016/05/30 Javascript
JS中跨页面调用变量和函数的方法(例如a.js 和 b.js中互相调用)
2016/11/01 Javascript
浅谈在koa2中实现页面渲染的全局数据
2017/10/09 Javascript
axios中cookie跨域及相关配置示例详解
2017/12/20 Javascript
基于AngularJs select绑定数字类型的问题
2018/10/08 Javascript
详解js访问对象的属性和方法
2018/10/25 Javascript
JS+canvas画布实现炫酷的旋转星空效果示例
2019/02/13 Javascript
在NPM发布自己造的轮子的方法步骤
2019/03/09 Javascript
浅谈KOA2 Restful方式路由初探
2019/03/14 Javascript
Vue实现PC端靠边悬浮球的代码
2020/05/09 Javascript
jQuery实现简单评论功能
2020/08/19 jQuery
高性能web服务器框架Tornado简单实现restful接口及开发实例
2014/07/16 Python
python定时检查某个进程是否已经关闭的方法
2015/05/20 Python
python urllib爬取百度云连接的实例代码
2017/06/19 Python
总结python中pass的作用
2019/02/27 Python
tensorflow对图像进行拼接的例子
2020/02/05 Python
python将logging模块封装成单独模块并实现动态切换Level方式
2020/05/12 Python
凯撒娱乐:Caesars Entertainment
2018/02/23 全球购物
两道JAVA笔试题
2016/09/14 面试题
蔬菜基地的创业计划书
2014/01/06 职场文书
教学大赛获奖感言
2014/01/15 职场文书
敬老院标语
2014/06/27 职场文书
优秀教研组申报材料
2014/12/26 职场文书
《比尾巴》教学反思
2016/02/24 职场文书
Vue3中toRef与toRefs的区别
2022/03/24 Vue.js
vue 把二维或多维数组转一维数组
2022/04/24 Vue.js