将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于mysql实现的简单队列以及跨进程锁实例详解
Jul 07 Python
python基础教程之对象和类的实际运用
Aug 29 Python
Python Tkinter基础控件用法
Sep 03 Python
python判断完全平方数的方法
Nov 13 Python
详解python--模拟轮盘抽奖游戏
Apr 12 Python
python async with和async for的使用
Jun 20 Python
Django 查询数据库并返回页面的例子
Aug 12 Python
解决Django migrate不能发现app.models的表问题
Aug 31 Python
tensorflow 实现自定义layer并添加到计算图中
Feb 04 Python
Pytest参数化parametrize使用代码实例
Feb 22 Python
python用opencv完成图像分割并进行目标物的提取
May 25 Python
Python3以GitHub为例来实现模拟登录和爬取的实例讲解
Jul 30 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
dedecms中显示数字验证码的修改方法
2007/03/21 PHP
windwos下使用php连接oracle数据库的过程分享
2014/05/26 PHP
PHP获取ip对应地区和使用网络类型的方法
2015/03/11 PHP
JS 控制CSS样式表
2009/08/20 Javascript
js apply/call/caller/callee/bind使用方法与区别分析
2009/10/28 Javascript
js关闭当前页面(窗口)的几种方式总结
2013/03/05 Javascript
Javascript加载速度慢的解决方案
2014/03/11 Javascript
js实现可得到不同颜色值的颜色选择器实例
2015/02/28 Javascript
js实现基于正则表达式的轻量提示插件
2015/08/29 Javascript
js无法获取到html标签的属性的解决方法
2016/07/26 Javascript
利用Query+bootstrap和js两种方式实现日期选择器
2017/01/10 Javascript
利用javascript实现的三种图片放大镜效果实例(附源码)
2017/01/23 Javascript
js实现3d悬浮效果
2017/02/16 Javascript
Vue.js项目中管理每个页面的头部标签的两种方法
2018/06/25 Javascript
layui表格数据重载
2019/07/27 Javascript
python使用在线API查询IP对应的地理位置信息实例
2014/06/01 Python
全面解读Python Web开发框架Django
2014/06/30 Python
python分割文件的常用方法
2014/11/01 Python
解决Python 爬虫URL中存在中文或特殊符号无法请求的问题
2018/05/11 Python
用Python写一个自动木马程序
2019/09/17 Python
Python利用全连接神经网络求解MNIST问题详解
2020/01/14 Python
CSS3条纹背景制作的实战攻略
2016/05/31 HTML / CSS
CSS3实现莲花绽放的动画效果
2020/11/06 HTML / CSS
瑞典时尚服装购物网站:Miinto.se
2017/10/30 全球购物
精彩的大学生自我评价
2013/11/17 职场文书
文科教师毕业的自我评价
2014/01/16 职场文书
单位实习证明怎么写
2014/01/17 职场文书
中秋节超市促销方案
2014/01/30 职场文书
电力公司个人求职信范文
2014/02/04 职场文书
电子工程专业毕业生求职信
2014/03/14 职场文书
计算机应用专业毕业生求职信
2014/06/03 职场文书
道路运输企业安全生产责任书
2014/07/28 职场文书
教师个人自我剖析材料
2014/09/29 职场文书
2014年维修工作总结
2014/11/22 职场文书
Nginx速查手册及常见问题
2022/04/07 Servers
Python自动操作神器PyAutoGUI的使用教程
2022/06/16 Python