Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python Django模板的使用方法(图文)
Nov 04 Python
图文讲解选择排序算法的原理及在Python中的实现
May 04 Python
Python读取word文本操作详解
Jan 22 Python
通过cmd进入python的实例操作
Jun 26 Python
Python Django 添加首页尾页上一页下一页代码实例
Aug 21 Python
scikit-learn线性回归,多元回归,多项式回归的实现
Aug 29 Python
Python 单例设计模式用法实例分析
Sep 23 Python
浅谈Python中threading join和setDaemon用法及区别说明
May 02 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 Python
Python调用ffmpeg开源视频处理库,批量处理视频
Nov 16 Python
Python数据分析入门之数据读取与存储
May 13 Python
Python+Tkinter制作专属图形化界面
Apr 01 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
Amazon Prime Video平台《无限住人 -IMMORTAL-》2020年开始TV放送!
2020/03/06 日漫
Web 前端设计模式--Dom重构 提高显示性能
2010/10/22 Javascript
用dtree实现树形菜单 dtree使用说明
2011/10/17 Javascript
基于JavaScript 数据类型之Boolean类型分析介绍
2013/04/19 Javascript
javascript实现右侧弹出“分享到”窗口效果
2016/02/01 Javascript
Javascript数组Array方法解读
2016/03/13 Javascript
bootstrap datetimepicker实现秒钟选择下拉框
2017/01/05 Javascript
基于Vue的文字跑马灯组件(npm 组件包)
2017/05/24 Javascript
微信小程序实现图片放大预览功能
2020/10/22 Javascript
JS处理一些简单计算题
2018/02/24 Javascript
Angular5升级RxJS到5.5.3报错:EmptyError: no elements in sequence的解决方法
2018/04/09 Javascript
ionic使用angularjs表单验证(模板验证)
2018/12/12 Javascript
微信小程序动态添加view组件的实例代码
2019/05/23 Javascript
详解package.json版本号规则
2019/08/01 Javascript
Python实现简易端口扫描器代码实例
2017/03/15 Python
python下10个简单实例代码
2017/11/15 Python
Tornado高并发处理方法实例代码
2018/01/15 Python
在Python中画图(基于Jupyter notebook的魔法函数)
2019/10/28 Python
Python爬虫爬取电影票房数据及图表展示操作示例
2020/03/27 Python
基于打开pycharm有带图片md文件卡死问题的解决
2020/04/24 Python
Django ForeignKey与数据库的FOREIGN KEY约束详解
2020/05/20 Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
2020/06/23 Python
获取CSDN文章内容并转换为markdown文本的python
2020/09/06 Python
英国著名的化妆品折扣网站:Allbeauty.com
2016/07/21 全球购物
美国领先的精品家居照明和装饰产品在线零售商:LightsOnline.com
2018/01/23 全球购物
Melissa鞋马来西亚官方网站:MDreams马来西亚
2018/04/05 全球购物
简单的大学生自我鉴定
2014/02/18 职场文书
创新型城市实施方案
2014/03/06 职场文书
委托公证书
2014/04/08 职场文书
活动倡议书范文
2014/05/13 职场文书
解除劳动合同证明书模板
2014/11/20 职场文书
党风廉政建设个人总结
2015/03/06 职场文书
2016年寒假政治学习心得体会
2015/10/09 职场文书
小学四年级班主任工作经验交流材料
2015/11/02 职场文书
小学体育队列队形教学反思
2016/02/16 职场文书
HTML页面点击按钮关闭页面的多种方式
2022/12/24 HTML / CSS