浅谈tensorflow模型保存为pb的各种姿势


Posted in Python onMay 25, 2020

一,直接保存pb

1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一。另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小。

这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出

主要使用tf.SavedModelBuilder来完成这个工作,并且可以把多个计算图保存到一个pb文件中,如果有多个MetaGraph,那么只会保留第一个MetaGraph的版本号。

保持pb的文件代码:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # 输出
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31

其实主要是:

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())

1.1 加载测试代码

from tensorflow.python.platform import gfile
 
sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程 
sess.run(tf.global_variables_initializer())
 
# 需要先复原变量
print(sess.run('b:0'))
# 1
 
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
 
op = sess.graph.get_tensor_by_name('op_to_store:0')
 
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

2,第二种就是采用上述的那API来进行保存

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32, name='x')
 y = tf.placeholder(tf.int32, name='y')
 b = tf.Variable(1, name='b')
 xy = tf.multiply(x, y)
 # 这里的输出需要加上name属性
 op = tf.add(xy, b, name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 
 # 测试 OP
 feed_dict = {x: 10, y: 3}
 print(sess.run(op, feed_dict))
 
 # 写入序列化的 PB 文件
 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31
 
 
 # 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])
 
# 添加第二个 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...
 
builder.save() # 保存 PB 模型

核心就是采用了:

# 官网有误,写成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
 # 输入输出信息字典,额外的信息
 builder.add_meta_graph_and_variables(sess,
          ['cpu_server_1'])

2.1 对应的测试代码为:

with tf.Session(graph=tf.Graph()) as sess:
 tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
 sess.run(tf.global_variables_initializer())
 
 input_x = sess.graph.get_tensor_by_name('x:0')
 input_y = sess.graph.get_tensor_by_name('y:0')
 
 op = sess.graph.get_tensor_by_name('op_to_store:0')
 
 ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
 print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入pb模型一样,也是要知道tensor的name,那么如何在不知道tensor name的情况下使用呢,给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

二,从ckpt进行加载

使用tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

浅谈tensorflow模型保存为pb的各种姿势

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图的结果,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph加载到当前默认的图来使用

ckpt.data是保存模型中每个变量的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型结构,将计算图中的变量取值以常量的形式保存

ckpt转换pb格式过程如下:

1,通过传入ckpt模型的路径得到模型的图和变量数据

2,通过import_meta_graph导入模型中的图

3,通过saver.restore从模型中恢复图中各个变量的数据

4,通过graph_util.convert_variables_to_constants将模型持久化

import tensorflow as tf 
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())

函数freeze_graph中,最重要的就是指定输出节点的名称,这个节点名称是原模型存在的结点,注意节点名称与张量名称的区别:

如:“input:0”是张量的名称,而“input”表示的是节点的名称

源码中通过graph = tf.get_default_graph()获得默认图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此就必须执行tf.train.import_meta_graph,再执行tf.get_default_graph()

1.2 一个小工具

tensorflow打印pb模型的所有节点

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
 with gfile.FastGFile(pb_path,'rb') as f:
  graph_def = tf.GraphDef()
 
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def,name='')
  for i,n in enumerate(graph_def.node):
   print("Name of the node -%s"%n.name)
tensorflow打印ckpt的所有节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print("tensor_name:",key)

方法二,除了上述办法外还有一种是需要通过源码的,这样既可以得到输出节点,还可以自定义输入节点。

import tensorflow as tf 
 
def model(input):
 net = tf.layers.conv2d(input,filters=32,kernel_size=3)
 net = tf.layers.batch_normalization(net,fused=False)
 net = tf.layers.separable_conv2d(net,32,3)
 net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output')
 
 return net 
 
input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path 
pb = pb_path 
 
with tf.Session() as sess:
 model1 = model(input_node)
 sess.run(tf.global_variables_initializer())
 output_node_names = 'output/BiasAdd'
 
 input_graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))
 
with tf.gfile.GFile(pb,'wb') as f:
 f.write(output_graph_def.SerializeToString())

注意:

节点名称和张量名称区别

类似于output是节点名称

类似于output:0是张量名称

方法三,其实是方法一的延伸可以配合tensorflow自带的一些工具来进行完成

freeze_graph

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

所以还是建议选择方法三

导出pb后的测试代码如下:下图是比较完成的测试代码与导出代码。

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()
  with open(pb_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tf.import_graph_def(output_graph_def, name="")
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
 
   # 定义输入的张量名称,对应网络结构的输入张量
   # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
   input_image_tensor = sess.graph.get_tensor_by_name("input:0")
   input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
   input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
   # 定义输出的张量名称
   output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
   # 读取测试图片
   im=read_image(image_path,resize_height,resize_width,normalization=True)
   im=im[np.newaxis,:]
   # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
   # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
   out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
              input_keep_prob_tensor:1.0,
              input_is_training_tensor:False})
   print("out:{}".format(out))
   score = tf.nn.softmax(out, name='pre')
   class_id = tf.argmax(score, 1)
   print "pre class_id:{}".format(sess.run(class_id))
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=sess.graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in sess.graph.get_operations():
  #  print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
   sess=sess,
   input_graph_def=input_graph_def,# 等于:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
  # for op in graph.get_operations():
  #  print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

以上这篇浅谈tensorflow模型保存为pb的各种姿势就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自定义类并使用的方法
May 07 Python
Python数据库的连接实现方法与注意事项
Feb 27 Python
Python实现霍夫圆和椭圆变换代码详解
Jan 12 Python
python range()函数取反序遍历sequence的方法
Jun 25 Python
Python3网络爬虫中的requests高级用法详解
Jun 18 Python
使用Rasterio读取栅格数据的实例讲解
Nov 26 Python
如何在django中添加日志功能
Feb 06 Python
Python3.9又更新了:dict内置新功能
Feb 28 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 Python
Python模拟登录和登录跳转的参考示例
Oct 30 Python
python多线程爬取西刺代理的示例代码
Jan 30 Python
python本地文件服务器实例教程
May 02 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
You might like
实用函数4
2007/11/08 PHP
用mysql触发器自动更新memcache的实现代码
2009/10/11 PHP
phpmyadmin里面导入sql语句格式的大量数据的方法
2010/06/05 PHP
PHP文章按日期(月日)SQL归档语句
2012/11/29 PHP
解析PHP获取当前网址及域名的实现代码
2013/06/23 PHP
PHP 过滤页面中的BOM(实现代码)
2013/06/29 PHP
php实现从上传文件创建缩略图的方法
2015/04/02 PHP
Android AsyncTack 异步任务实例详解
2016/11/02 PHP
php实用代码片段整理
2016/11/12 PHP
php+js实现百度地图多点标注的方法
2016/11/30 PHP
php文件管理基本功能简单操作
2017/01/16 PHP
Laravel配合jwt使用的方法实例
2020/10/25 PHP
jquery命令汇总,方便使用jquery的朋友
2012/06/26 Javascript
jquerydom对象的事件隐藏显示和对象数组示例
2013/12/10 Javascript
浅谈js中子页面父页面方法 变量相互调用
2016/08/04 Javascript
jquery.validate.js 多个相同name的处理方式
2017/07/10 jQuery
Vue自定义组件的四种方式示例详解
2020/02/28 Javascript
JS手写一个自定义Promise操作示例
2020/03/16 Javascript
js实现简单的无缝轮播效果
2020/09/05 Javascript
python打开url并按指定块读取网页内容的方法
2015/04/29 Python
举例讲解Python的Tornado框架实现数据可视化的教程
2015/05/02 Python
Python中用post、get方式提交数据的方法示例
2017/09/22 Python
Python实现将doc转化pdf格式文档的方法
2018/01/19 Python
python 实现让字典的value 成为列表
2019/12/16 Python
Python中import导入不同目录的模块方法详解
2020/02/18 Python
Python能做什么
2020/06/02 Python
Python使用socket_TCP实现小文件下载功能
2020/10/09 Python
德国团购网站:Groupon德国
2018/03/13 全球购物
领导干部廉政承诺书
2014/03/27 职场文书
论文指导教师评语
2014/04/28 职场文书
团日活动总结书格式
2014/05/08 职场文书
幼儿园教师工作总结2015
2015/04/02 职场文书
2016年优秀党务工作者先进事迹材料
2016/02/29 职场文书
解决golang在import自己的包报错的问题
2021/04/29 Golang
linux下安装redis图文详细步骤
2021/12/04 Redis
利用Python实时获取steam特惠游戏数据
2022/06/25 Python