tensorflow没有output结点,存储成pb文件的例子


Posted in Python onJanuary 04, 2020

Tensorflow中保存成pb file 需要 使用函数

graph_util.convert_variables_to_constants(sess, sess.graph_def,

output_node_names=[]) []中需要填写你需要保存的结点。如果保存的结点在神经网络中没有被显示定义该怎么办?

例如我使用了tf.contrib.slim或者keras,在tf的高层很多情况下都会这样。

在写神经网络时,只需要简单的一层层传导,一个slim.conv2d层就包含了kernal,bias,activation function,非常的方便,好处是网络结构一目了然,坏处是什么呢?

tensorflow没有output结点,存储成pb文件的例子

在尝试保存pb的 output node names时,需要将最后的输出结点保存下来,与这个结点相关的,从输入开始,经过层层传递的嵌套函数或者操作的相关结点,都会被保存,但无效的例如 计算准确率,计算loss等,就可以省略了,因为保存的pb主要是用来做预测的。

在准备查看所有的结点名称并选取保存时,发现scope "local3"里面仅有相关的weights 和biases,这两个是单独存在的,即保存这两个参数并没有任何意义。

tensorflow没有output结点,存储成pb文件的例子

那么这时候有两种解决办法:

方法一:

graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=[var.name[:-2] for var in tf.global_variables()])

那么这个的意思是所有的variable的都被保存下来 但函数中要求的是 node name 我们通过 global_variables获得的是 变量名 并不是 节点名

(例如 output:0 就是变量名,又叫tensor name)

output就是 node name了。

在tensorboard中可以一窥究竟

tensorflow没有output结点,存储成pb文件的例子

通过这样 也可以将 所有的变量全部保存下来(但是你并不能使用,是因为你的output并没有名字,所以你不可以通过常用的sess.graph.get_tensor_by_name来使用)

方法二:

那就是直接改写神经网络了....当然了还是比较简单的,只要改写最后一个,改写成output即可,tensorflow中无论是 变量、操作op、函数、都可以命名,那么这个地方是一个简单的全连接,仅需要将weights*net(上一层的输出) +bias 即可,我们只要将bias相加的结果命名为 ouput即可:

with tf.name_scope('local3'):
 
  local3_weights = tf.Variable(tf.truncated_normal([4096, self.output_size], stddev=0.1))
 
  local3_bias = tf.Variable(tf.constant(0.1, shape=[self.output_size]))
 
result = tf.add(tf.matmul(net, local3_weights), local3_bias, name="output")

这样将上述的convert_variables_to_constants中的output_node_names只需要填写一个['output']即可,因为这一个output结点,需要从input开始,将所有的神经网络前向传播的操作和参数全部保存下来,因此保存的结点数量 和 方法一保存的结点数量是一样的(console显示都是 convert 24)。

完整的pb保存为:(我是将ckpt读入进来,然后存成pb的)

from tensorflow.python.platform import gfile
 
 
 
load_ckpt():
 
  path = './data/output/loss1.0/'
 
  print("read from ckpt")
 
  ckpt = tf.train.get_checkpoint_state(path)
 
  saver = tf.train.Saver()
 
  saver.restore(sess, ckpt.model_checkpoint_path)
 
 
 
def write2pb_file():
 
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,
 
    output_node_names=["output"])
 
  with tf.gfile.GFile(path+'loss1.0.pb', mode='wb') as f:
 
  f.write(constant_graph.SerializeToString())
 
  print("Model is saved as " + path+'loss1.0.pb')
 
 
 
def main():
 
  load_ckpt()
 
  write2pb_file()

如果是简单的直接保存,那就更简单了。

pb文件的read,很多人会将一个net写成一个类,在引入的时候会将新建这个类,然后读入ckpt文件,这完全没有问题,但是在读取pb时,就会发生问题,因为pb中已经包含了图与参数,引入时会创建一个默认的图,但是net类中自己也会创建一个图,那么这时候你运行程序,参数其实并没有使用.pb的文件。

所以我们不能创建net类,然后直接读入.pb文件,对.pb文件,通过如下代码,获取.pb的graph中的输入和输出。

self.output = self.sess.graph.get_tensor_by_name("output:0")
 
self.input = self.sess.graph.get_tensor_by_name("images:0")

注意此时要加:0 因为你获取的不再是结点了,而是一个真实的变量,我的理解是,结点相当于一个类,:0是对象,默认初始化值就是对象的初始化。

然后就可以通过self.sess.run(self.output(feed_dict={self.input: your_input})))运行你的网络了!

以上这篇tensorflow没有output结点,存储成pb文件的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现基于两张图片生成圆角图标效果的方法
Mar 26 Python
python统计日志ip访问数的方法
Jul 06 Python
Python使用smtp和pop简单收发邮件完整实例
Jan 09 Python
Python中捕获键盘的方式详解
Mar 28 Python
Python实现性能自动化测试竟然如此简单
Jul 30 Python
Python 多线程其他属性以及继承Thread类详解
Aug 28 Python
python 内置函数汇总详解
Sep 16 Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 Python
windows支持哪个版本的python
Jul 03 Python
python环境搭建和pycharm的安装配置及汉化详细教程(零基础小白版)
Aug 19 Python
selenium+超级鹰实现模拟登录12306
Jan 24 Python
python 实现IP子网计算
Feb 18 Python
TensorFlow查看输入节点和输出节点名称方式
Jan 04 #Python
根据tensor的名字获取变量的值方式
Jan 04 #Python
将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例
Jan 04 #Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
Jan 04 #Python
对tensorflow中的strides参数使用详解
Jan 04 #Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 #Python
TensorFlow tf.nn.max_pool实现池化操作方式
Jan 04 #Python
You might like
PHP Zip解压 文件在线解压缩的函数代码
2010/05/26 PHP
PHP Parse Error: syntax error, unexpected $end 错误的解决办法
2012/06/05 PHP
str_replace只替换一次字符串的方法
2013/04/09 PHP
laravel请求参数校验方法
2019/10/10 PHP
如何在Laravel之外使用illuminate组件详解
2020/09/20 PHP
Google Map Api和GOOGLE Search Api整合实现代码
2009/07/18 Javascript
基于jQuery的history历史记录插件
2010/12/11 Javascript
封装了一个js图片轮换效果的函数
2011/09/28 Javascript
用jquery实现点击栏目背景色改变
2012/12/10 Javascript
JQuery表格内容过滤的实现方法
2013/07/05 Javascript
JS随机漂浮广告代码具体实例
2013/11/19 Javascript
关闭ie窗口清除Session的解决方法
2014/01/10 Javascript
javascript创建createXmlHttpRequest对象示例代码
2014/02/10 Javascript
当某个文本框成为焦点时即清除文本框内容
2014/04/28 Javascript
JavaScript sup方法入门实例(把字符串显示为上标)
2014/10/20 Javascript
js中的json对象详细介绍
2014/10/29 Javascript
jQuery中scrollLeft()方法用法实例
2015/01/16 Javascript
NodeJS处理Express中异步错误
2017/03/26 NodeJs
浅析JS中常用类型转换及运算符表达式
2017/07/23 Javascript
微信小程序之滚动视图容器的实现方法
2017/09/26 Javascript
在 React、Vue项目中使用SVG的方法
2018/02/09 Javascript
详解关于Vue2.0路由开启keep-alive时需要注意的地方
2018/09/18 Javascript
node.js开发辅助工具nodemon安装与配置详解
2020/02/06 Javascript
JavaScript实现指定数量的并发限制的示例代码
2020/03/10 Javascript
layui实现显示数据表格、搜索和修改功能示例
2020/06/03 Javascript
Vue单文件组件开发实现过程详解
2020/07/30 Javascript
[43:53]OG vs EG 2019国际邀请赛淘汰赛 胜者组 BO3 第三场 8.22
2019/09/05 DOTA
Python 命令行参数sys.argv
2008/09/06 Python
详解在Python程序中解析并修改XML内容的方法
2015/11/16 Python
使用Python编写爬虫的基本模块及框架使用指南
2016/01/20 Python
python实现简易动态时钟
2018/11/19 Python
pandas DataFrame索引行列的实现
2019/06/04 Python
大四学年自我鉴定
2013/11/13 职场文书
Oracle设置DB、监听和EM开机启动的方法
2021/04/25 Oracle
python实现自动清理文件夹旧文件
2021/05/10 Python
MySQL中正则表达式(REGEXP)使用详解
2022/07/07 MySQL