将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 编程之twisted详解及简单实例
Jan 28 Python
Python异常对代码运行性能的影响实例解析
Feb 08 Python
解决python "No module named pip" 的问题
Oct 13 Python
Python编程深度学习计算库之numpy
Dec 28 Python
django 2.2和mysql使用的常见问题
Jul 18 Python
使用django和vue进行数据交互的方法步骤
Nov 11 Python
Python 基于wxpy库实现微信添加好友功能(简洁)
Nov 29 Python
pytorch 实现模型不同层设置不同的学习率方式
Jan 06 Python
python匿名函数lambda原理及实例解析
Feb 07 Python
500行python代码实现飞机大战
Apr 24 Python
Python安装Bs4的多种方法
Nov 28 Python
python实现MySQL指定表增量同步数据到clickhouse的脚本
Feb 26 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
PHP模拟SQL Server的两个日期处理函数
2006/10/09 PHP
深入理解PHP类的自动载入机制
2016/09/16 PHP
php模式设计之观察者模式应用实例分析
2019/09/25 PHP
PHP7 其他语言层面的修改
2021/03/09 PHP
JavaScript调用Activex控件的事件的实现方法
2010/04/11 Javascript
jquery autocomplete自动完成插件的的使用方法
2010/08/07 Javascript
简约JS日历控件 实例代码
2013/07/12 Javascript
javascript使用正则表达式实现去掉空格之后的字符
2015/02/15 Javascript
多种js图片预加载实现方式分享
2016/02/19 Javascript
浅谈jquery采用attr修改form表单enctype不起作用的问题
2016/11/25 Javascript
大白话讲解JavaScript的Promise
2017/04/06 Javascript
js实现图片上传预览原理分析
2017/07/13 Javascript
基于js中document.cookie全面解析
2017/09/14 Javascript
react 不用插件实现数字滚动的效果示例
2020/04/14 Javascript
Swiper实现导航栏滚动效果
2020/10/16 Javascript
Django小白教程之Django用户注册与登录
2016/04/22 Python
Python读取视频的两种方法(imageio和cv2)
2018/04/15 Python
Python读取excel指定列生成指定sql脚本的方法
2018/11/28 Python
Python子类继承父类构造函数详解
2019/02/19 Python
Python Web框架之Django框架cookie和session用法分析
2019/08/16 Python
在Python中等距取出一个数组其中n个数的实现方式
2019/11/27 Python
Django app配置多个数据库代码实例
2019/12/17 Python
Transpose 数组行列转置的限制方式
2020/02/11 Python
Linux系统下升级pip的完整步骤
2021/01/31 Python
python 列表推导和生成器表达式的使用
2021/02/01 Python
阿里巴巴美国:Alibaba美国
2019/11/24 全球购物
意大利单身交友网站:Meetic
2020/07/12 全球购物
写给女朋友的道歉信
2014/01/08 职场文书
班主任与学生安全责任书
2014/07/25 职场文书
项目申请汇报材料
2014/08/16 职场文书
初中家长意见
2015/06/03 职场文书
2016年公司新年寄语
2015/08/17 职场文书
HTML+CSS实现导航条下拉菜单的示例代码
2021/08/02 HTML / CSS
电脑关机速度很慢怎么办 提升电脑关机速度设置教程
2022/04/08 数码科技
Grafana可视化监控系统结合SpringBoot使用
2022/04/19 Redis
Golang jwt身份认证
2022/04/20 Golang