Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的爬虫包Beautiful Soup中用正则表达式来搜索
Jan 20 Python
Python和Java进行DES加密和解密的实例
Jan 09 Python
Python实现时钟显示效果思路详解
Apr 11 Python
django从请求到响应的过程深入讲解
Aug 01 Python
Python实现的字典排序操作示例【按键名key与键值value排序】
Dec 21 Python
python实现支付宝转账接口
May 07 Python
选择python进行数据分析的理由和优势
Jun 25 Python
python选取特定列 pandas iloc,loc,icol的使用详解(列切片及行切片)
Aug 06 Python
pandas 缺失值与空值处理的实现方法
Oct 12 Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 Python
python七种方法判断字符串是否包含子串
Aug 18 Python
pycharm 实现光标快速移动到括号外或行尾的操作
Feb 05 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
DOTA2 6.87版本后新眼位详解攻略
2020/04/20 DOTA
WINDOWS下php5.2.4+mysql6.0+apache2.2.4+ZendOptimizer-3.3.0配置
2008/03/28 PHP
php中iconv函数使用方法
2008/05/24 PHP
php 图片上传类代码
2009/07/17 PHP
php对称加密算法示例
2014/05/07 PHP
PHP实现的进度条效果详解
2016/05/03 PHP
PHP PDOStatement::closeCursor讲解
2019/01/30 PHP
比较简单的一个符合web标准的JS调用flash方法
2007/11/29 Javascript
JS获取屏幕,浏览器窗口大小,网页高度宽度(实现代码)
2013/12/17 Javascript
WordPress 单页面上一页下一页的实现方法【附代码】
2016/03/10 Javascript
AngularJS基础 ng-include 指令简单示例
2016/08/01 Javascript
jQuery插件Easyui设置datagrid的pageNumber导致两次请求问题的解决方法
2016/08/06 Javascript
深入浅析ES6 Class 中的 super 关键字
2017/10/20 Javascript
JS实现求5的阶乘示例
2019/01/21 Javascript
微信小程序绑定手机号获取验证码功能
2019/10/22 Javascript
ant design实现圈选功能
2019/12/17 Javascript
vue学习笔记之作用域插槽实例分析
2020/02/01 Javascript
微信小程序实现带放大效果的轮播图
2020/05/26 Javascript
[02:39]DOTA2国际邀请赛助威团西雅图第一天
2013/08/08 DOTA
[04:21]狐狸妈带你到现场 DOTA2 TI中国区预选赛线下赛路线指引
2014/05/22 DOTA
Python正则表达式指南 推荐
2018/10/09 Python
python 实现UTC时间加减的方法
2018/12/31 Python
Python字典的核心底层原理讲解
2019/01/24 Python
浅析Python 读取图像文件的性能对比
2019/03/07 Python
python使用百度文字识别功能方法详解
2019/07/23 Python
python抓取多种类型的页面方法实例
2019/11/20 Python
浅谈Pytorch torch.optim优化器个性化的使用
2020/02/20 Python
执行Python程序时模块报错问题
2020/03/26 Python
css3类选择器之结合元素选择器和多类选择器用法
2017/03/09 HTML / CSS
HTML5之消息通知的使用(Web Notification)
2018/10/30 HTML / CSS
金融管理毕业生求职信
2014/03/03 职场文书
关于感恩的演讲稿500字
2014/08/26 职场文书
竞选纪律委员演讲稿
2014/09/13 职场文书
意外死亡赔偿协议书
2014/10/14 职场文书
预备党员转正党小组意见
2015/06/01 职场文书
员工聘用合同范本
2015/09/21 职场文书