Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
压缩包密码破解示例分享(类似典破解)
Jan 17 Python
浅要分析Python程序与C程序的结合使用
Apr 07 Python
Hadoop中的Python框架的使用指南
Apr 22 Python
儿童python练习实例
May 27 Python
python numpy 部分排序 寻找最大的前几个数的方法
Jun 27 Python
使用python实现简单五子棋游戏
Jun 18 Python
python读写csv文件并增加行列的实例代码
Aug 01 Python
Django app配置多个数据库代码实例
Dec 17 Python
Python如何实现自带HTTP文件传输服务
Jul 08 Python
Scrapy基于scrapy_redis实现分布式爬虫部署的示例
Sep 29 Python
python 中的jieba分词库
Nov 23 Python
Python函数中apply、map、applymap的区别
Nov 27 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
php中用文本文件做数据库的实现方法
2008/03/27 PHP
PHP 根据IP地址控制访问的代码
2010/04/22 PHP
应用开发中涉及到的css和php笔记分享
2011/08/02 PHP
PHP解码unicode编码的中文字符代码分享
2014/08/13 PHP
PHP多态代码实例
2015/06/26 PHP
Javascript中匿名函数的多种调用方式总结
2013/12/06 Javascript
JavaScript编程中布尔对象的基本使用
2015/10/25 Javascript
JavaScript正则表达式exec/g实现多次循环用法示例
2017/01/17 Javascript
详解Vue2.0 事件派发与接收
2017/09/05 Javascript
LayUI表格批量删除方法
2018/08/15 Javascript
Vue.js 中的 v-model 指令及绑定表单元素的方法
2018/12/03 Javascript
深入理解vue中的slot与slot-scope
2019/04/22 Javascript
为什么Vue3.0使用Proxy实现数据监听(defineProperty表示不背这个锅)
2019/10/14 Javascript
微信小程序如何实现精确的日期时间选择器
2020/01/21 Javascript
vscode 配置vue+vetur+eslint+prettier自动格式化功能
2020/03/23 Javascript
详解element上传组件before-remove钩子问题解决
2020/04/08 Javascript
解决vue项目input输入框双向绑定数据不实时生效问题
2020/08/05 Javascript
vue实现按钮切换图片
2021/01/20 Vue.js
Python浅拷贝与深拷贝用法实例
2015/05/09 Python
深入解析Python编程中JSON模块的使用
2015/10/15 Python
Python利用itchat对微信中好友数据实现简单分析的方法
2017/11/21 Python
十分钟利用Python制作属于你自己的个性logo
2018/05/07 Python
python面向对象多线程爬虫爬取搜狐页面的实例代码
2018/05/31 Python
Python实现的在特定目录下导入模块功能分析
2019/02/11 Python
python三方库之requests的快速上手
2019/03/04 Python
django使用多个数据库的方法实例
2021/03/04 Python
专科毕业生就业推荐信
2013/11/01 职场文书
《桂林山水》教学反思
2014/02/08 职场文书
《音乐之都维也纳》教学反思
2014/04/16 职场文书
供货协议书范本
2014/04/22 职场文书
竞选大学学委演讲稿
2014/09/13 职场文书
2014年团委工作总结
2014/11/13 职场文书
2014年师德师风工作总结
2014/11/25 职场文书
民事起诉状范文
2015/05/19 职场文书
使用Html+Css实现简易导航栏功能(导航栏遇到鼠标切换背景颜色)
2021/04/07 HTML / CSS
Apache POI操作批量导入MySQL数据库
2022/06/21 Servers