浅谈tensorflow模型保存为pb的各种姿势


Posted in Python onMay 25, 2020

一,直接保存pb

1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一。另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小。

这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出

主要使用tf.SavedModelBuilder来完成这个工作,并且可以把多个计算图保存到一个pb文件中,如果有多个MetaGraph,那么只会保留第一个MetaGraph的版本号。

保持pb的文件代码:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # 输出
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31

其实主要是:

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())

1.1 加载测试代码

from tensorflow.python.platform import gfile
 
sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程 
sess.run(tf.global_variables_initializer())
 
# 需要先复原变量
print(sess.run('b:0'))
# 1
 
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
 
op = sess.graph.get_tensor_by_name('op_to_store:0')
 
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

2,第二种就是采用上述的那API来进行保存

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31
 
 
 # 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])
 
# 添加第二个 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...
 
builder.save() # 保存 PB 模型

核心就是采用了:

# 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])

2.1 对应的测试代码为:

with tf.Session(graph=tf.Graph()) as sess:
 tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
 sess.run(tf.global_variables_initializer())
 
 input_x = sess.graph.get_tensor_by_name('x:0')
 input_y = sess.graph.get_tensor_by_name('y:0')
 
 op = sess.graph.get_tensor_by_name('op_to_store:0')
 
 ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
 print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入pb模型一样,也是要知道tensor的name,那么如何在不知道tensor name的情况下使用呢,给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

二,从ckpt进行加载

使用tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

浅谈tensorflow模型保存为pb的各种姿势

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图的结果,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph加载到当前默认的图来使用

ckpt.data是保存模型中每个变量的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型结构,将计算图中的变量取值以常量的形式保存

ckpt转换pb格式过程如下:

1,通过传入ckpt模型的路径得到模型的图和变量数据

2,通过import_meta_graph导入模型中的图

3,通过saver.restore从模型中恢复图中各个变量的数据

4,通过graph_util.convert_variables_to_constants将模型持久化

import tensorflow as tf 
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())

函数freeze_graph中,最重要的就是指定输出节点的名称,这个节点名称是原模型存在的结点,注意节点名称与张量名称的区别:

如:“input:0”是张量的名称,而“input”表示的是节点的名称

源码中通过graph = tf.get_default_graph()获得默认图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此就必须执行tf.train.import_meta_graph,再执行tf.get_default_graph()

1.2 一个小工具

tensorflow打印pb模型的所有节点

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
 with gfile.FastGFile(pb_path,'rb') as f:
  graph_def = tf.GraphDef()
 
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def,name='')
  for i,n in enumerate(graph_def.node):
   print("Name of the node -%s"%n.name)
tensorflow打印ckpt的所有节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print("tensor_name:",key)

方法二,除了上述办法外还有一种是需要通过源码的,这样既可以得到输出节点,还可以自定义输入节点。

import tensorflow as tf 
 
def model(input):
 net = tf.layers.conv2d(input,filters=32,kernel_size=3)
 net = tf.layers.batch_normalization(net,fused=False)
 net = tf.layers.separable_conv2d(net,32,3)
 net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output')
 
 return net 
 
input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path 
pb = pb_path 
 
with tf.Session() as sess:
 model1 = model(input_node)
 sess.run(tf.global_variables_initializer())
 output_node_names = 'output/BiasAdd'
 
 input_graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))
 
with tf.gfile.GFile(pb,'wb') as f:
 f.write(output_graph_def.SerializeToString())

注意:

节点名称和张量名称区别

类似于output是节点名称

类似于output:0是张量名称

方法三,其实是方法一的延伸可以配合tensorflow自带的一些工具来进行完成

freeze_graph

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

所以还是建议选择方法三

导出pb后的测试代码如下:下图是比较完成的测试代码与导出代码。

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()
  with open(pb_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tf.import_graph_def(output_graph_def, name="")
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
 
   # 定义输入的张量名称,对应网络结构的输入张量
   # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
   input_image_tensor = sess.graph.get_tensor_by_name("input:0")
   input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
   input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
   # 定义输出的张量名称
   output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
   # 读取测试图片
   im=read_image(image_path,resize_height,resize_width,normalization=True)
   im=im[np.newaxis,:]
   # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
   # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
   out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
              input_keep_prob_tensor:1.0,
              input_is_training_tensor:False})
   print("out:{}".format(out))
   score = tf.nn.softmax(out, name='pre')
   class_id = tf.argmax(score, 1)
   print "pre class_id:{}".format(sess.run(class_id))
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=sess.graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in sess.graph.get_operations():
  #  print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

以上这篇浅谈tensorflow模型保存为pb的各种姿势就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python在不同目录下导入模块的实现方法
Oct 27 Python
python kmeans聚类简单介绍和实现代码
Feb 23 Python
python3 图片referer防盗链的实现方法
Mar 12 Python
TensorFlow 合并/连接数组的方法
Jul 27 Python
对Python 窗体(tkinter)文本编辑器(Text)详解
Oct 11 Python
Python实现繁体中文与简体中文相互转换的方法示例
Dec 18 Python
pandas DataFrame创建方法的方式
Aug 02 Python
Python学习笔记之列表和成员运算符及列表相关方法详解
Aug 22 Python
Python Opencv提取图片中某种颜色组成的图形的方法
Sep 19 Python
Python使用ElementTree美化XML格式的操作
Mar 06 Python
python爬虫多次请求超时的几种重试方法(6种)
Dec 01 Python
Django drf请求模块源码解析
Jun 08 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
You might like
一步一步学习PHP(6) 面向对象
2010/02/16 PHP
Windows下利用Gvim写PHP产生中文乱码问题解决方法
2011/04/20 PHP
九个你必须知道而且又很好用的php函数和特点
2013/08/08 PHP
php判断输入是否是纯数字,英文,汉字的方法
2015/03/05 PHP
详解 PHP加密解密字符串函数附源码下载
2015/12/18 PHP
PHP里的$_GET数组介绍
2019/03/22 PHP
Yii框架Session与Cookie使用方法示例
2019/10/14 PHP
学习jquery必备 api中英文对照的chm手册 下载
2007/05/03 Javascript
JSON JQUERY模板实现说明
2010/07/03 Javascript
Javascript无参数和有参数类继承问题解决方法
2015/03/02 Javascript
BootStrap与validator 使用笔记(JAVA SpringMVC实现)
2016/09/21 Javascript
jQuery+PHP+Mysql实现抽奖程序
2020/04/12 jQuery
Angular2 父子组件通信方式的示例
2018/01/29 Javascript
JavaScript实现计算圆周率到小数点后100位的方法示例
2018/05/08 Javascript
在vue项目中引入highcharts图表的方法
2019/01/21 Javascript
Vue 中 template 有且只能一个 root的原因解析(源码分析)
2020/04/11 Javascript
JS实现简易贪吃蛇游戏
2020/08/24 Javascript
python线程池的实现实例
2013/11/18 Python
分享一下Python 开发者节省时间的10个方法
2015/10/02 Python
完美解决python遍历删除字典里值为空的元素报错问题
2016/09/11 Python
解决pip install的时候报错timed out的问题
2018/06/12 Python
Python enumerate函数功能与用法示例
2019/03/01 Python
python同步windows和linux文件
2019/08/29 Python
详解Python+Selenium+ChromeDriver的配置和问题解决
2021/01/19 Python
Jacadi Paris美国官方网站:法国童装品牌
2017/10/15 全球购物
北京天润融通.net面试题笔试题
2012/02/20 面试题
超市促销实习自我鉴定
2013/09/23 职场文书
报关专员求职信范文
2014/02/22 职场文书
结婚喜宴主持词
2014/03/14 职场文书
2014年优秀党员材料
2014/12/18 职场文书
幼儿园安全教育月活动总结
2015/05/08 职场文书
2015年为民办实事工作总结
2015/05/26 职场文书
小学英语听课心得体会
2016/01/14 职场文书
何时使用Map来代替普通的JS对象
2021/04/29 Javascript
Redis过期数据是否会被立马删除
2022/07/23 Redis
Python中np.random.randint()参数详解及用法实例
2022/09/23 Python