浅谈tensorflow模型保存为pb的各种姿势


Posted in Python onMay 25, 2020

一,直接保存pb

1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一。另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小。

这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出

主要使用tf.SavedModelBuilder来完成这个工作,并且可以把多个计算图保存到一个pb文件中,如果有多个MetaGraph,那么只会保留第一个MetaGraph的版本号。

保持pb的文件代码:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # 输出
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31

其实主要是:

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())

1.1 加载测试代码

from tensorflow.python.platform import gfile
 
sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程 
sess.run(tf.global_variables_initializer())
 
# 需要先复原变量
print(sess.run('b:0'))
# 1
 
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
 
op = sess.graph.get_tensor_by_name('op_to_store:0')
 
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

2,第二种就是采用上述的那API来进行保存

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31
 
 
 # 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])
 
# 添加第二个 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...
 
builder.save() # 保存 PB 模型

核心就是采用了:

# 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])

2.1 对应的测试代码为:

with tf.Session(graph=tf.Graph()) as sess:
 tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
 sess.run(tf.global_variables_initializer())
 
 input_x = sess.graph.get_tensor_by_name('x:0')
 input_y = sess.graph.get_tensor_by_name('y:0')
 
 op = sess.graph.get_tensor_by_name('op_to_store:0')
 
 ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
 print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入pb模型一样,也是要知道tensor的name,那么如何在不知道tensor name的情况下使用呢,给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

二,从ckpt进行加载

使用tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

浅谈tensorflow模型保存为pb的各种姿势

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图的结果,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph加载到当前默认的图来使用

ckpt.data是保存模型中每个变量的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型结构,将计算图中的变量取值以常量的形式保存

ckpt转换pb格式过程如下:

1,通过传入ckpt模型的路径得到模型的图和变量数据

2,通过import_meta_graph导入模型中的图

3,通过saver.restore从模型中恢复图中各个变量的数据

4,通过graph_util.convert_variables_to_constants将模型持久化

import tensorflow as tf 
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())

函数freeze_graph中,最重要的就是指定输出节点的名称,这个节点名称是原模型存在的结点,注意节点名称与张量名称的区别:

如:“input:0”是张量的名称,而“input”表示的是节点的名称

源码中通过graph = tf.get_default_graph()获得默认图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此就必须执行tf.train.import_meta_graph,再执行tf.get_default_graph()

1.2 一个小工具

tensorflow打印pb模型的所有节点

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
 with gfile.FastGFile(pb_path,'rb') as f:
  graph_def = tf.GraphDef()
 
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def,name='')
  for i,n in enumerate(graph_def.node):
   print("Name of the node -%s"%n.name)
tensorflow打印ckpt的所有节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print("tensor_name:",key)

方法二,除了上述办法外还有一种是需要通过源码的,这样既可以得到输出节点,还可以自定义输入节点。

import tensorflow as tf 
 
def model(input):
 net = tf.layers.conv2d(input,filters=32,kernel_size=3)
 net = tf.layers.batch_normalization(net,fused=False)
 net = tf.layers.separable_conv2d(net,32,3)
 net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output')
 
 return net 
 
input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path 
pb = pb_path 
 
with tf.Session() as sess:
 model1 = model(input_node)
 sess.run(tf.global_variables_initializer())
 output_node_names = 'output/BiasAdd'
 
 input_graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))
 
with tf.gfile.GFile(pb,'wb') as f:
 f.write(output_graph_def.SerializeToString())

注意:

节点名称和张量名称区别

类似于output是节点名称

类似于output:0是张量名称

方法三,其实是方法一的延伸可以配合tensorflow自带的一些工具来进行完成

freeze_graph

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

所以还是建议选择方法三

导出pb后的测试代码如下:下图是比较完成的测试代码与导出代码。

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()
  with open(pb_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tf.import_graph_def(output_graph_def, name="")
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
 
   # 定义输入的张量名称,对应网络结构的输入张量
   # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
   input_image_tensor = sess.graph.get_tensor_by_name("input:0")
   input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
   input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
   # 定义输出的张量名称
   output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
   # 读取测试图片
   im=read_image(image_path,resize_height,resize_width,normalization=True)
   im=im[np.newaxis,:]
   # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
   # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
   out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
              input_keep_prob_tensor:1.0,
              input_is_training_tensor:False})
   print("out:{}".format(out))
   score = tf.nn.softmax(out, name='pre')
   class_id = tf.argmax(score, 1)
   print "pre class_id:{}".format(sess.run(class_id))
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=sess.graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in sess.graph.get_operations():
  #  print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

以上这篇浅谈tensorflow模型保存为pb的各种姿势就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python的身份证号码自动生成程序
Aug 15 Python
Python中replace方法实例分析
Aug 20 Python
python输出当前目录下index.html文件路径的方法
Apr 28 Python
Python使用matplotlib填充图形指定区域代码示例
Jan 16 Python
Python 输入一个数字判断成绩分数等级的方法
Nov 15 Python
python利用插值法对折线进行平滑曲线处理
Dec 25 Python
Python利用itchat库向好友或者公众号发消息的实例
Feb 21 Python
python实现控制COM口的示例
Jul 03 Python
解决Django Static内容不能加载显示的问题
Jul 28 Python
Python GUI库PyQt5图形和特效样式QSS介绍
Feb 25 Python
Python3标准库glob文件名模式匹配的问题
Mar 13 Python
浅谈TensorFlow之稀疏张量表示
Jun 30 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
You might like
php+highchats生成动态统计图
2014/05/21 PHP
PHP中的排序函数sort、asort、rsort、krsort、ksort区别分析
2014/08/18 PHP
PHP的Yii框架中移除组件所绑定的行为的方法
2016/03/18 PHP
PHP判断JSON对象是否存在的方法(推荐)
2016/07/06 PHP
详解PHP防止直接访问.php 文件的实现方法
2017/07/28 PHP
setTimeout和setInterval的浏览器兼容性分析
2007/02/27 Javascript
js form 验证函数 当前比较流行的错误提示
2009/06/23 Javascript
通过JavaScript控制字体大小的代码
2011/10/04 Javascript
动态加载dtree.js树treeview(示例代码)
2013/12/17 Javascript
javascript将相对路径转绝对路径示例
2014/03/14 Javascript
JavaScript使用encodeURI()和decodeURI()获取字符串值的方法
2015/08/04 Javascript
JavaScript中访问id对象 属性的方式访问属性(实例代码)
2016/10/28 Javascript
javaScript生成支持中文带logo的二维码(jquery.qrcode.js)
2017/01/03 Javascript
利用Vue实现一个markdown编辑器实例代码
2019/05/19 Javascript
Vue 实现登录界面验证码功能
2020/01/03 Javascript
python进程类subprocess的一些操作方法例子
2014/11/22 Python
python numpy函数中的linspace创建等差数列详解
2017/10/13 Python
Python利用公共键如何对字典列表进行排序详解
2018/05/19 Python
python删除字符串中指定字符的方法
2018/08/13 Python
python opencv判断图像是否为空的实例
2019/01/26 Python
在Python中使用Neo4j的方法
2019/03/14 Python
使用python动态生成波形曲线的实现
2019/12/04 Python
python中Array和DataFrame相互转换的实例讲解
2021/02/03 Python
BrandAlley英国:法国折扣奢侈品网上零售商
2017/07/03 全球购物
KARATOV珠宝在线商店:俄罗斯珠宝品牌
2019/03/13 全球购物
CK巴西官方网站:Calvin Klein巴西
2019/07/19 全球购物
Pandora德国官网:购买潘多拉手链、戒指、项链和耳环
2020/02/20 全球购物
大四学生毕业自荐信
2013/11/07 职场文书
决心书标准格式
2014/03/11 职场文书
企业演讲比赛主持词
2014/03/18 职场文书
岗位说明书怎么写
2014/07/30 职场文书
三好学生事迹材料
2014/12/24 职场文书
导游词之南昌滕王阁
2019/11/29 职场文书
Go缓冲channel和非缓冲channel的区别说明
2021/04/25 Golang
详解MySQL连接挂死的原因
2021/05/18 MySQL
微信小程序基础教程之echart的使用
2021/06/01 Javascript