Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python解决计数原理问题的方法
Aug 04 Python
python嵌套字典比较值与取值的实现示例
Nov 03 Python
浅谈python 里面的单下划线与双下划线的区别
Dec 01 Python
python随机取list中的元素方法
Apr 08 Python
基于python 处理中文路径的终极解决方法
Apr 12 Python
Django读取Mysql数据并显示在前端的实例
May 27 Python
python远程连接服务器MySQL数据库
Jul 02 Python
python 运用Django 开发后台接口的实例
Dec 11 Python
python tkinter库实现气泡屏保和锁屏
Jul 29 Python
Python脚本实现监听服务器的思路代码详解
May 28 Python
python logging 重复写日志问题解决办法详解
Aug 04 Python
pandas数值排序的实现实例
Jul 25 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
PHP 危险函数全解析
2009/09/09 PHP
php中session_unset与session_destroy的区别分析
2011/06/16 PHP
ThinkPHP项目分组配置方法分析
2016/03/23 PHP
javascript写的一个链表实现代码
2009/10/25 Javascript
CutePsWheel javascript libary 控制输入文本框为可使用滚轮控制的js库
2010/02/07 Javascript
JQuery插件fancybox无法在弹出层使用左右键的解决办法
2013/12/25 Javascript
jQuery中delegate()方法用法实例
2015/01/19 Javascript
你不知道的高性能JAVASCRIPT
2016/01/18 Javascript
jQuery Validate插件实现表单验证
2016/08/19 Javascript
JQuery和HTML5 Canvas实现弹幕效果
2017/01/04 Javascript
nodejs个人博客开发第一步 准备工作
2017/04/12 NodeJs
详解用node-images 打造简易图片服务器
2017/05/08 Javascript
AngularJS中重新加载当前路由页面的方法
2018/03/09 Javascript
JavaScript设计模式之责任链模式实例分析
2019/01/16 Javascript
基于小程序请求接口wx.request封装的类axios请求
2020/07/02 Javascript
[32:56]完美世界DOTA2联赛PWL S3 Rebirth vs CPG 第二场 12.11
2020/12/16 DOTA
python编程实现12306的一个小爬虫实例
2017/12/27 Python
python  创建一个保留重复值的列表的补码
2018/10/15 Python
Python字符串逆序输出的实例讲解
2019/02/16 Python
Python魔法方法功能与用法简介
2019/04/04 Python
对python中基于tcp协议的通信(数据传输)实例讲解
2019/07/22 Python
浅谈pytorch卷积核大小的设置对全连接神经元的影响
2020/01/10 Python
Python os模块常用方法和属性总结
2020/02/20 Python
Python如何使用paramiko模块连接linux
2020/03/18 Python
Python 使用 PyQt5 开发的关机小工具分享
2020/07/16 Python
Django DRF认证组件流程实现原理详解
2020/08/17 Python
CSS3实现大小不一的粒子旋转加载动画
2016/04/21 HTML / CSS
简单整理HTML5的基本特性和语法
2016/02/18 HTML / CSS
哪些情况下不应该使用索引
2015/07/20 面试题
师范应届生语文教师求职信
2013/10/29 职场文书
亮剑精神演讲稿
2014/05/23 职场文书
医药公司采购员岗位职责
2015/04/03 职场文书
2015年三好一满意工作总结
2015/07/24 职场文书
selenium.webdriver中add_argument方法常用参数表
2021/04/08 Python
CSS变量实现主题切换的方法
2021/06/23 HTML / CSS
关于MySQL中explain工具的使用
2023/05/08 MySQL