tensorflow实现将ckpt转pb文件的方法


Posted in Python onApril 22, 2020

   本博客实现将自己训练保存的ckpt模型转换为pb文件,该方法适用于任何ckpt模型,当然你需要确定ckpt模型输入/输出的节点名称。

   使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

    例如:下面的代码运行后,会在save目录下保存了四个文件:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

    其中,checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值
   但很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

一、CKPT 转换成 PB格式

    将CKPT 转换成 PB格式的文件的过程可简述如下:

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
 下面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())

说明:

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

 2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

4、实质上,我们可以直接在恢复的会话sess中,获得默认的网络图,更简单的方法,如下:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

调用方法很简单,输入ckpt模型路径,输出pb模型的路径即可:

    # 输入ckpt模型路径
    input_checkpoint='models/model.ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="models/pb/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

5、上面以及说明:在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。因此,其他网络模型,也可以通过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

二、 pb模型预测

    下面是预测pb模型的代码

def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))

说明:

1、与ckpt预测不同的是,pb文件已经固化了网络模型结构,因此,即使不知道原训练模型(train)的源码,我们也可以恢复网络图,并进行预测。恢复模型十分简单,只需要从读取的序列化数据中导入网络结构即可:

tf.import_graph_def(output_graph_def, name="")
2、但必须知道原网络模型的输入和输出的节点名称(当然了,传递数据时,是通过输入输出的张量来完成的)。由于InceptionV3模型的输入有三个节点,因此这里需要定义输入的张量名称,它对应网络结构的输入张量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")
input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
以及输出的张量名称:

output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

3、预测时,需要feed输入数据:

# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4、其他网络模型预测时,也可以通过修改输入和输出的张量的名称 。

       PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"InceptionV3/Logits/SpatialSqueeze:0"

完整的CKPT 转换成 PB格式和预测的代码如下:

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))
 
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in sess.graph.get_operations():
 # print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

三、源码下载和资料推荐

    1、训练方法
     上面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:

《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:https://blog.csdn.net/guyuealian/article/details/81560537

    2、Github地址
Github源码:https://github.com/PanJinquan/tensorflow_models_nets  中的convert_pb.py文件

预训练模型下载地址:http://xiazai.3water.com/202004/yuanma/googlenet_inception_3water.rar

    3、将模型移植Android的方法
     pb文件是可以移植到Android平台运行的,其方法,可参考:

《将tensorflow训练好的模型移植到Android (MNIST手写数字识别)》

参考:

[1] https://3water.com/article/185209.htm

【2】https://3water.com/article/185206.htm

到此这篇关于tensorflow实现将ckpt转pb文件的方法的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
编写Python的web框架中的Model的教程
Apr 29 Python
python实现图书馆研习室自动预约功能
Apr 27 Python
Python 脚本获取ES 存储容量的实例
Dec 27 Python
对python for 文件指定行读写操作详解
Dec 29 Python
python自动化之Ansible的安装教程
Jun 13 Python
python如何以表格形式打印输出的方法示例
Jun 21 Python
python list转置和前后反转的例子
Aug 26 Python
Python3视频转字符动画的实例代码
Aug 29 Python
python3实现网页版raspberry pi(树莓派)小车控制
Feb 12 Python
python 已知一个字符,在一个list中找出近似值或相似值实现模糊匹配
Feb 29 Python
keras模型保存为tensorflow的二进制模型方式
May 25 Python
keras实现多GPU或指定GPU的使用介绍
Jun 17 Python
jupyter lab文件导出/下载方式
Apr 22 #Python
python模拟实现分发扑克牌
Apr 22 #Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
Apr 22 #Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
You might like
人尽可用的Windows技巧小贴士之下篇
2007/03/22 PHP
PHP中几种常见的超时处理全面总结
2012/09/11 PHP
基于php在各种web服务器的运行模式详解
2013/06/03 PHP
Thinkphp5框架实现图片、音频和视频文件的上传功能详解
2019/08/27 PHP
javascript 常用方法总结
2009/06/03 Javascript
JQuery 学习笔记 选择器之四
2009/07/23 Javascript
JavaScript小技巧 2.5 则
2010/09/12 Javascript
当达到输入长度时表单自动切换焦点
2014/04/06 Javascript
js实现简单秒表走动的时钟特效
2020/03/25 Javascript
JavaScript实现点击单选按钮改变输入框中文本域内容的方法
2015/08/12 Javascript
Javascript验证方法大全
2015/09/21 Javascript
Jquery获取当前城市的天气信息
2016/08/05 Javascript
AngularJs 国际化(I18n/L10n)详解
2016/09/01 Javascript
JS基于onclick事件实现单个按钮的编辑与保存功能示例
2017/02/13 Javascript
Bootstrap modal 多弹窗之叠加引起的滚动条遮罩阴影问题
2017/02/27 Javascript
vue实现移动端图片裁剪上传功能
2020/08/18 Javascript
JS实现的哈夫曼编码示例【原始版与修改版】
2018/04/22 Javascript
AngularJs的UI组件ui-Bootstrap之Tooltip和Popover
2018/07/13 Javascript
Vue头像处理方案小结
2018/07/26 Javascript
JS中数据结构与算法---排序算法(Sort Algorithm)实例详解
2019/06/17 Javascript
[14:03]2017DOTA2亚洲邀请赛开幕式:12神兵演绎水墨中华
2017/04/01 DOTA
在Python中使用CasperJS获取JS渲染生成的HTML内容的教程
2015/04/09 Python
TensorBoard 计算图的查看方式
2020/02/15 Python
解决pyqt5异常退出无提示信息的问题
2020/04/08 Python
python的json包位置及用法总结
2020/06/21 Python
navabi英国:设计师大码女装
2019/06/25 全球购物
美国椅子和沙发制造商:La-Z-Boy
2020/10/25 全球购物
加拿大户外探险购物网站:SAIL
2020/06/27 全球购物
应届生会计电算化求职信
2013/10/03 职场文书
有针对性的求职自荐信
2013/11/14 职场文书
公司合作协议书范本
2014/04/18 职场文书
乡镇综治宣传月活动总结
2014/07/02 职场文书
大学生入党积极分子党校学习思想汇报
2014/10/25 职场文书
乡镇党建工作汇报材料
2014/10/27 职场文书
2014年班务工作总结
2014/12/02 职场文书
Android开发手册Chip监听及ChipGroup监听
2022/06/10 Java/Android