Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中请使用isinstance()判断变量类型
Aug 25 Python
Python入门篇之列表和元组
Oct 17 Python
python和bash统计CPU利用率的方法
Jul 10 Python
Python连接phoenix的方法示例
Sep 29 Python
Python调用C语言的方法【基于ctypes模块】
Jan 22 Python
pytorch 输出中间层特征的实例
Aug 17 Python
python数组循环处理方法
Aug 26 Python
OpenCV模板匹配matchTemplate的实现
Oct 18 Python
python函数局部变量、全局变量、递归知识点总结
Nov 15 Python
Django单元测试中Fixtures的使用方法
Feb 26 Python
python绘制高斯曲线
Feb 19 Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
Apr 08 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
php设计模式 FlyWeight (享元模式)
2011/06/26 PHP
php日期转时间戳,指定日期转换成时间戳
2012/07/17 PHP
Laravel框架数据库CURD操作、连贯操作总结
2014/09/03 PHP
CI配置多数据库访问的方法
2016/03/28 PHP
Yii框架表单模型和验证用法
2016/05/20 PHP
Centos7安装swoole扩展操作示例
2020/03/26 PHP
详细分析PHP 命名空间(namespace)
2020/06/30 PHP
基于jquery1.4.2的仿flash超炫焦点图播放效果
2010/04/20 Javascript
javascript date格式化示例
2013/09/25 Javascript
怎么判断js脚本加载完成
2014/02/28 Javascript
javascript中select下拉框的用法总结
2016/01/07 Javascript
有关文件上传 非ajax提交 得到后台数据问题
2016/10/12 Javascript
微信开发 微信授权详解
2016/10/21 Javascript
laydate 显示结束时间不小于开始时间的实例
2017/08/11 Javascript
angular4中*ngFor不能对返回来的对象进行循环的解决方法
2018/09/12 Javascript
vue动画打包后失效问题的解决方法
2018/09/18 Javascript
详解webpack-dev-middleware 源码解读
2020/03/23 Javascript
[52:20]VP vs VG Supermajor小组赛 B组胜者组决赛 BO3 第一场 6.2
2018/06/03 DOTA
python 生成目录树及显示文件大小的代码
2009/07/23 Python
简单的抓取淘宝图片的Python爬虫
2014/12/25 Python
Python模拟登录验证码(代码简单)
2016/02/06 Python
Python中functools模块函数解析
2017/03/12 Python
python3中获取文件当前绝对路径的两种方法
2018/04/26 Python
python3第三方爬虫库BeautifulSoup4安装教程
2018/06/19 Python
详解pandas中MultiIndex和对象实际索引不一致问题
2019/07/23 Python
python中Mako库实例用法
2020/12/31 Python
HTML5注册页面示例代码
2014/03/27 HTML / CSS
New Balance美国官网:运动鞋和健身服装
2017/04/11 全球购物
马耳他航空公司官方网站:Air Malta
2019/05/15 全球购物
IGK Hair官网:喷雾、洗发水、护发素等
2020/11/03 全球购物
实习期自我鉴定
2013/10/11 职场文书
医学毕业生自我鉴定
2013/10/30 职场文书
第一批党的群众路线教育实践活动工作总结
2014/03/03 职场文书
毕业论文答辩开场白
2015/05/27 职场文书
2015年社区国庆节活动总结
2015/07/30 职场文书
2016年百日安全生产活动总结
2016/04/06 职场文书