Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python中列表、字符串、字典的常用操作
Sep 19 Python
python单线程文件传输的实例(C/S)
Feb 13 Python
python爬虫 Pyppeteer使用方法解析
Sep 28 Python
python双端队列原理、实现与使用方法分析
Nov 27 Python
python3 tcp的粘包现象和解决办法解析
Dec 09 Python
对Tensorflow中tensorboard日志的生成与显示详解
Feb 04 Python
浅析python标准库中的glob
Mar 13 Python
Django实现从数据库中获取到的数据转换为dict
Mar 27 Python
jupyter notebook 添加kernel permission denied的操作
Apr 21 Python
python logging通过json文件配置的步骤
Apr 27 Python
Python中的Cookie模块如何使用
Jun 04 Python
编译 pycaffe时报错:fatal error: numpy/arrayobject.h没有那个文件或目录
Nov 29 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
PHP Ajax实现页面无刷新发表评论
2007/01/02 PHP
PHP提取数据库内容中的图片地址并循环输出
2010/03/21 PHP
Yii2.0使用阿里云OSS的SDK上传图片、下载、删除图片示例
2017/09/20 PHP
Laravel框架集成UEditor编辑器的方法图文与实例详解
2019/04/17 PHP
Javascript 实现TreeView CheckBox全选效果
2010/01/11 Javascript
为JavaScript类型增加方法的实现代码(增加功能)
2011/12/29 Javascript
nodejs教程 安装express及配置app.js文件的详细步骤
2013/05/11 NodeJs
js转义字符介绍
2013/11/05 Javascript
浅析JavaScript中的同名标识符优先级
2013/12/06 Javascript
JS批量修改PS中图层名称的方法
2014/01/26 Javascript
jQuery下拉友情链接美化效果代码分享
2015/08/26 Javascript
基于jQuery实现简单的折叠菜单效果
2015/11/23 Javascript
详解nodeJS中读写文件方法的区别
2017/03/06 NodeJs
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
JavaScript面向对象精要(下部)
2017/09/12 Javascript
nodejs require js文件入口,在package.json中指定默认入口main方法
2018/10/10 NodeJs
jQuery实现数字自动增加或者减少的动画效果示例
2018/12/11 jQuery
详解react阻止无效重渲染的多种方式
2018/12/11 Javascript
在LayUI图片上传中,解决由跨域问题引起的请求接口错误的方法
2019/09/24 Javascript
[02:29]完美世界高校联赛上海赛区回顾
2015/12/15 DOTA
Python ldap实现登录实例代码
2016/09/30 Python
python 调用win32pai 操作cmd的方法
2017/05/28 Python
使用python生成目录树
2018/03/29 Python
python3写的简单本地文件上传服务器实例
2018/06/04 Python
Python netmiko模块的使用
2020/02/14 Python
Python3开发实例之非关系型图数据库Neo4j安装方法及Python3连接操作Neo4j方法实例
2020/03/18 Python
日本最大的旅游网站:Rakuten Travel(乐天旅游)
2018/08/02 全球购物
班级心理活动总结
2014/07/04 职场文书
演讲比赛的活动方案
2014/08/28 职场文书
党员教师群众路线个人整改措施
2014/10/28 职场文书
食堂卫生管理制度
2015/08/04 职场文书
Python 用户输入和while循环的操作
2021/05/23 Python
详解Java分布式事务的 6 种解决方案
2021/06/26 Java/Android
Java实现二分搜索树的示例代码
2022/03/17 Java/Android
ipad隐藏软件app图标方法
2022/04/19 数码科技
Java使用HttpClient实现文件下载
2022/08/14 Java/Android