浅谈tensorflow模型保存为pb的各种姿势


Posted in Python onMay 25, 2020

一,直接保存pb

1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一。另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小。

这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出

主要使用tf.SavedModelBuilder来完成这个工作,并且可以把多个计算图保存到一个pb文件中,如果有多个MetaGraph,那么只会保留第一个MetaGraph的版本号。

保持pb的文件代码:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # 输出
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31

其实主要是:

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())

1.1 加载测试代码

from tensorflow.python.platform import gfile
 
sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程 
sess.run(tf.global_variables_initializer())
 
# 需要先复原变量
print(sess.run('b:0'))
# 1
 
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
 
op = sess.graph.get_tensor_by_name('op_to_store:0')
 
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

2,第二种就是采用上述的那API来进行保存

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31
 
 
 # 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])
 
# 添加第二个 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...
 
builder.save() # 保存 PB 模型

核心就是采用了:

# 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])

2.1 对应的测试代码为:

with tf.Session(graph=tf.Graph()) as sess:
 tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
 sess.run(tf.global_variables_initializer())
 
 input_x = sess.graph.get_tensor_by_name('x:0')
 input_y = sess.graph.get_tensor_by_name('y:0')
 
 op = sess.graph.get_tensor_by_name('op_to_store:0')
 
 ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
 print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入pb模型一样,也是要知道tensor的name,那么如何在不知道tensor name的情况下使用呢,给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

二,从ckpt进行加载

使用tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

浅谈tensorflow模型保存为pb的各种姿势

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图的结果,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph加载到当前默认的图来使用

ckpt.data是保存模型中每个变量的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型结构,将计算图中的变量取值以常量的形式保存

ckpt转换pb格式过程如下:

1,通过传入ckpt模型的路径得到模型的图和变量数据

2,通过import_meta_graph导入模型中的图

3,通过saver.restore从模型中恢复图中各个变量的数据

4,通过graph_util.convert_variables_to_constants将模型持久化

import tensorflow as tf 
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())

函数freeze_graph中,最重要的就是指定输出节点的名称,这个节点名称是原模型存在的结点,注意节点名称与张量名称的区别:

如:“input:0”是张量的名称,而“input”表示的是节点的名称

源码中通过graph = tf.get_default_graph()获得默认图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此就必须执行tf.train.import_meta_graph,再执行tf.get_default_graph()

1.2 一个小工具

tensorflow打印pb模型的所有节点

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
 with gfile.FastGFile(pb_path,'rb') as f:
  graph_def = tf.GraphDef()
 
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def,name='')
  for i,n in enumerate(graph_def.node):
   print("Name of the node -%s"%n.name)
tensorflow打印ckpt的所有节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print("tensor_name:",key)

方法二,除了上述办法外还有一种是需要通过源码的,这样既可以得到输出节点,还可以自定义输入节点。

import tensorflow as tf 
 
def model(input):
 net = tf.layers.conv2d(input,filters=32,kernel_size=3)
 net = tf.layers.batch_normalization(net,fused=False)
 net = tf.layers.separable_conv2d(net,32,3)
 net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output')
 
 return net 
 
input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path 
pb = pb_path 
 
with tf.Session() as sess:
 model1 = model(input_node)
 sess.run(tf.global_variables_initializer())
 output_node_names = 'output/BiasAdd'
 
 input_graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))
 
with tf.gfile.GFile(pb,'wb') as f:
 f.write(output_graph_def.SerializeToString())

注意:

节点名称和张量名称区别

类似于output是节点名称

类似于output:0是张量名称

方法三,其实是方法一的延伸可以配合tensorflow自带的一些工具来进行完成

freeze_graph

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

所以还是建议选择方法三

导出pb后的测试代码如下:下图是比较完成的测试代码与导出代码。

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()
  with open(pb_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tf.import_graph_def(output_graph_def, name="")
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
 
   # 定义输入的张量名称,对应网络结构的输入张量
   # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
   input_image_tensor = sess.graph.get_tensor_by_name("input:0")
   input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
   input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
   # 定义输出的张量名称
   output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
   # 读取测试图片
   im=read_image(image_path,resize_height,resize_width,normalization=True)
   im=im[np.newaxis,:]
   # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
   # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
   out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
              input_keep_prob_tensor:1.0,
              input_is_training_tensor:False})
   print("out:{}".format(out))
   score = tf.nn.softmax(out, name='pre')
   class_id = tf.argmax(score, 1)
   print "pre class_id:{}".format(sess.run(class_id))
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=sess.graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in sess.graph.get_operations():
  #  print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

以上这篇浅谈tensorflow模型保存为pb的各种姿势就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python清除指定目录内所有文件中script的方法
Jun 30 Python
python 基础教程之Map使用方法
Jan 17 Python
python爬虫_微信公众号推送信息爬取的实例
Oct 23 Python
详解Python中的四种队列
May 21 Python
Python实现的合并两个有序数组算法示例
Mar 04 Python
详解python读取image
Apr 03 Python
Python中的字符串切片(截取字符串)的详解
May 15 Python
python中字符串数组逆序排列方法总结
Jun 23 Python
Django CSRF跨站请求伪造防护过程解析
Jul 31 Python
Python数据库小程序源代码
Sep 15 Python
Django框架ORM数据库操作实例详解
Nov 07 Python
使用Pandas将inf, nan转化成特定的值
Dec 19 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
You might like
PHP MVC框架skymvc支持多文件上传
2016/05/26 PHP
php利用云片网实现短信验证码功能的示例代码
2017/11/18 PHP
php中html_entity_decode实现HTML实体转义
2018/06/13 PHP
windows系统php环境安装swoole具体步骤
2021/03/04 PHP
通过ifame指向的页面高度调整iframe的高度
2006/10/05 Javascript
解析jquery获取父窗口的元素
2013/06/26 Javascript
js的[defer]和[async]属性
2014/11/24 Javascript
JavaScript统计字符串中每个字符出现次数完整实例
2016/01/28 Javascript
javascript实现右下角广告框效果
2017/02/01 Javascript
jQuery的三种bind/One/Live/On事件绑定使用方法
2017/02/23 Javascript
基于Vue自定义指令实现按钮级权限控制思路详解
2018/05/23 Javascript
小程序实现搜索界面 小程序实现推荐搜索列表效果
2019/05/18 Javascript
JavaScript获取当前url路径过程解析
2019/12/27 Javascript
如何在JavaScript中创建具有多个空格的字符串?
2020/02/23 Javascript
微信小程序中的上拉、下拉菜单功能
2020/03/13 Javascript
jQuery 动态粒子效果示例代码
2020/07/07 jQuery
python数据结构之二叉树的建立实例
2014/04/29 Python
Python判断Abundant Number的方法
2015/06/15 Python
Python cookbook(数据结构与算法)将序列分解为单独变量的方法
2018/02/13 Python
Python使用numpy模块创建数组操作示例
2018/06/20 Python
django数据关系一对多、多对多模型、自关联的建立
2019/07/24 Python
详解Python利用configparser对配置文件进行读写操作
2020/11/03 Python
python 基于pygame实现俄罗斯方块
2021/03/02 Python
中国综合性网上购物商城:当当(网上卖书起家)
2016/11/16 全球购物
波兰品牌内衣及泳装网上商店:Astratex.pl
2017/02/03 全球购物
美国最大的户外装备和服装购物网站:Backcountry
2019/10/15 全球购物
adidas瑞典官方网站:购买阿迪达斯鞋子和运动服
2019/12/11 全球购物
linux面试题参考答案(3)
2012/09/13 面试题
部门年终奖分配方案
2014/05/07 职场文书
大学教师个人总结
2015/02/10 职场文书
西柏坡观后感
2015/06/08 职场文书
优质服务标语口号
2015/12/26 职场文书
入党转正申请书范文
2019/05/20 职场文书
2019让人心动的商业计划书
2019/06/27 职场文书
Nginx反爬虫策略,防止UA抓取网站
2021/03/31 Servers
Apache Hudi 加速传统的批处理模式
2022/04/24 Servers