Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python datetime时间格式化去掉前导0
Jul 31 Python
Python中的包和模块实例
Nov 22 Python
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
python类继承用法实例分析
May 27 Python
Python2.x中文乱码问题解决方法
Jun 02 Python
Python判断文本中消息重复次数的方法
Apr 27 Python
Python基础之函数的定义与使用示例
Mar 23 Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 Python
用Python绘制漫步图实例讲解
Feb 26 Python
Python3 filecmp模块测试比较文件原理解析
Mar 23 Python
Python读取yaml文件的详细教程
Jul 21 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
老照片 - 几十年前的收音机与人
2021/03/02 无线电
PHP下一个非常全面获取图象信息的函数
2008/11/20 PHP
php adodb分页实现代码
2009/03/19 PHP
解析php 版获取重定向后的地址(代码)
2013/06/26 PHP
PHP实现的比较完善的购物车类
2014/12/02 PHP
php如何连接sql server
2015/10/16 PHP
thinkphp实现把数据库中的列的值存到下拉框中的方法
2017/01/20 PHP
深入解析Laravel5.5中的包自动发现Package Auto Discovery
2017/09/13 PHP
游戏人文件夹程序 ver 3.0
2006/07/14 Javascript
VUEJS实战之利用laypage插件实现分页(3)
2016/06/13 Javascript
angular中的http拦截器Interceptors的实现
2017/02/21 Javascript
Bootstrap Table使用整理(二)
2017/06/09 Javascript
详解Angular 开发环境搭建
2017/06/22 Javascript
vue.js数据绑定的方法(单向、双向和一次性绑定)
2017/07/13 Javascript
深入浅析vue组件间事件传递
2017/12/29 Javascript
Vue条件循环判断+计算属性+绑定样式v-bind的实例
2018/09/18 Javascript
Nodejs对postgresql基本操作的封装方法
2019/02/20 NodeJs
微信小程序iOS下拉白屏晃动问题解决方案
2019/10/12 Javascript
Vue中错误图片的处理的实现代码
2019/11/07 Javascript
[01:35]2018年度CS GO最佳战队-完美盛典
2018/12/17 DOTA
使用python实现省市三级菜单效果
2016/01/20 Python
利用Python自动监控网站并发送邮件告警的方法
2016/08/24 Python
python读取与写入csv格式文件的示例代码
2017/12/16 Python
Python实现多条件筛选目标数据功能【测试可用】
2018/06/13 Python
Node.js 和 Python之间该选择哪个?
2020/08/05 Python
深入理解css中vertical-align属性
2017/04/18 HTML / CSS
CSS3中的display:grid,网格布局介绍
2019/10/30 HTML / CSS
Christys’ Hats官网:英国帽子制造商
2018/11/28 全球购物
Mountain Hardwear官网:攀岩服装和户外装备
2019/09/26 全球购物
Paper Cape官网:美国婴儿和儿童服装品牌
2019/11/02 全球购物
华美博弈C/VC工程师笔试试题
2012/07/16 面试题
《锄禾》教学反思
2014/04/08 职场文书
酒店宣传语大全
2015/07/13 职场文书
少先大队干部竞选稿
2015/11/20 职场文书
医务人员岗前培训心得体会
2016/01/08 职场文书
导游词之凤凰古城
2019/10/22 职场文书