Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python笔记(叁)继续学习
Oct 24 Python
二种python发送邮件实例讲解(python发邮件附件可以使用email模块实现)
Dec 03 Python
Python使用OpenCV进行标定
May 08 Python
python内存动态分配过程详解
Jul 15 Python
Python3 tkinter 实现文件读取及保存功能
Sep 12 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
python错误调试及单元文档测试过程解析
Dec 19 Python
django2.2 和 PyMySQL版本兼容问题
Feb 17 Python
python 日志 logging模块详细解析
Mar 31 Python
Pygame框架实现飞机大战
Aug 07 Python
使用Python实现音频双通道分离
Dec 25 Python
Python中基础数据类型 set集合知识点总结
Aug 02 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
thinkphp实现把数据库中的列的值存到下拉框中的方法
2017/01/20 PHP
Google韩国首页图标动画效果
2007/08/26 Javascript
JS中如何判断传过来的JSON数据中是否存在某字段
2014/08/18 Javascript
JavaScript用JQuery呼叫Server端方法示例代码
2014/09/03 Javascript
jquery实现将获取的颜色值转换为十六进制形式的方法
2014/12/20 Javascript
js实现交换运动效果的方法
2015/04/10 Javascript
jquery实现的仿天猫侧导航tab切换效果
2015/08/24 Javascript
基于jquery实现日历签到功能
2020/09/11 Javascript
深入浅析JavaScript中的arguments对象(强力推荐)
2016/06/03 Javascript
JS实现的跨浏览器解析XML文件实例
2016/06/21 Javascript
Nodejs抓取html页面内容(推荐)
2016/08/11 NodeJs
将鼠标焦点定位到文本框最后(代码分享)
2017/01/11 Javascript
jQuery取得元素标签名称小结(附代码)
2017/08/16 jQuery
JavaScript数组方法的错误使用例子
2018/09/13 Javascript
Vue.js中 v-model 指令的修饰符详解
2018/12/03 Javascript
详解用JS添加和删除class类名
2019/03/25 Javascript
JS原形与原型链深入详解
2020/05/09 Javascript
nuxt 自定义 auth 中间件实现令牌的持久化操作
2020/11/05 Javascript
python网络编程学习笔记(一)
2014/06/09 Python
Python函数返回值实例分析
2015/06/08 Python
浅谈Matplotlib简介和pyplot的简单使用——文本标注和箭头
2018/01/09 Python
python写一个md5解密器示例
2018/02/23 Python
查看django版本的方法分享
2018/05/14 Python
Pandas之DataFrame对象的列和索引之间的转化
2019/06/25 Python
Django中多种重定向方法使用详解
2019/07/17 Python
python隐藏类中属性的3种实现方法
2019/12/19 Python
Django DRF APIView源码运行流程详解
2020/08/17 Python
基于python实现简单网页服务器代码实例
2020/09/14 Python
Python + opencv对拍照得到的图片进行背景去除的实现方法
2020/11/18 Python
Python爬虫定时计划任务的几种常见方法(推荐)
2021/01/15 Python
什么时候用assert
2015/05/08 面试题
高等教育学专业自荐书
2014/06/17 职场文书
2016个人先进事迹材料范文
2016/03/01 职场文书
四年级作文之植物
2019/09/20 职场文书
微信小程序和php的登录实现
2021/04/01 PHP
css中有哪些方式可以隐藏页面元素及区别
2022/06/16 HTML / CSS