tensorflow实现将ckpt转pb文件的方法


Posted in Python onApril 22, 2020

   本博客实现将自己训练保存的ckpt模型转换为pb文件,该方法适用于任何ckpt模型,当然你需要确定ckpt模型输入/输出的节点名称。

   使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

    例如:下面的代码运行后,会在save目录下保存了四个文件:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

    其中,checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值
   但很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

一、CKPT 转换成 PB格式

    将CKPT 转换成 PB格式的文件的过程可简述如下:

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
 下面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())

说明:

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

 2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

4、实质上,我们可以直接在恢复的会话sess中,获得默认的网络图,更简单的方法,如下:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

调用方法很简单,输入ckpt模型路径,输出pb模型的路径即可:

    # 输入ckpt模型路径
    input_checkpoint='models/model.ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="models/pb/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

5、上面以及说明:在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。因此,其他网络模型,也可以通过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

二、 pb模型预测

    下面是预测pb模型的代码

def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))

说明:

1、与ckpt预测不同的是,pb文件已经固化了网络模型结构,因此,即使不知道原训练模型(train)的源码,我们也可以恢复网络图,并进行预测。恢复模型十分简单,只需要从读取的序列化数据中导入网络结构即可:

tf.import_graph_def(output_graph_def, name="")
2、但必须知道原网络模型的输入和输出的节点名称(当然了,传递数据时,是通过输入输出的张量来完成的)。由于InceptionV3模型的输入有三个节点,因此这里需要定义输入的张量名称,它对应网络结构的输入张量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")
input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
以及输出的张量名称:

output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

3、预测时,需要feed输入数据:

# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4、其他网络模型预测时,也可以通过修改输入和输出的张量的名称 。

       PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"InceptionV3/Logits/SpatialSqueeze:0"

完整的CKPT 转换成 PB格式和预测的代码如下:

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))
 
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in sess.graph.get_operations():
 # print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

三、源码下载和资料推荐

    1、训练方法
     上面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:

《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:https://blog.csdn.net/guyuealian/article/details/81560537

    2、Github地址
Github源码:https://github.com/PanJinquan/tensorflow_models_nets  中的convert_pb.py文件

预训练模型下载地址:http://xiazai.3water.com/202004/yuanma/googlenet_inception_3water.rar

    3、将模型移植Android的方法
     pb文件是可以移植到Android平台运行的,其方法,可参考:

《将tensorflow训练好的模型移植到Android (MNIST手写数字识别)》

参考:

[1] https://3water.com/article/185209.htm

【2】https://3water.com/article/185206.htm

到此这篇关于tensorflow实现将ckpt转pb文件的方法的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
零基础写python爬虫之使用urllib2组件抓取网页内容
Nov 04 Python
web.py 十分钟创建简易博客实现代码
Apr 22 Python
python获取list下标及其值的简单方法
Sep 12 Python
基于python爬虫数据处理(详解)
Jun 10 Python
Python3多线程操作简单示例
May 22 Python
推荐10款最受Python开发者欢迎的Python IDE
Sep 16 Python
Python实用库 PrettyTable 学习笔记
Aug 06 Python
Python 函数用法简单示例【定义、参数、返回值、函数嵌套】
Sep 20 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 Python
基于python3抓取pinpoint应用信息入库
Jan 08 Python
一些关于python 装饰器的个人理解
Aug 31 Python
python时间time模块处理大全
Oct 25 Python
jupyter lab文件导出/下载方式
Apr 22 #Python
python模拟实现分发扑克牌
Apr 22 #Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
Apr 22 #Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
You might like
PHP 翻页 实例代码
2009/08/07 PHP
基于PHP对XML的操作详解
2013/06/07 PHP
Symfony2实现在doctrine中内置数据的方法
2016/02/05 PHP
php运行报错Call to undefined function curl_init()的最新解决方法
2016/11/20 PHP
php远程请求CURL实例教程(爬虫、保存登录状态)
2020/12/10 PHP
js 与或运算符 || && 妙用
2009/12/09 Javascript
js 获取坐标 通过JS得到当前焦点(鼠标)的坐标属性
2013/01/04 Javascript
js中top/parent/frame概述及案例应用
2013/02/06 Javascript
JS保留两位小数 四舍五入函数的小例子
2013/11/20 Javascript
JS动态创建DOM元素的方法
2015/06/09 Javascript
js注入 黑客之路必备!
2016/09/14 Javascript
微信小程序 canvas API详解及实例代码
2016/10/08 Javascript
[原创]javascript typeof id==='string'?document.getElementById(id):id解释
2016/11/02 Javascript
详解微信小程序开发之下拉刷新 上拉加载
2016/11/24 Javascript
jQuery 全选 全不选 事件绑定的实现代码
2017/01/23 Javascript
微信小程序 本地数据读取实例
2017/04/27 Javascript
vue中倒计时组件的实例代码
2018/07/06 Javascript
ES6 Map结构的应用实例分析
2019/06/26 Javascript
非常实用的jQuery代码段集锦【检测浏览器、滚动、复制、淡入淡出等】
2019/08/08 jQuery
如何在Vue中抽离接口配置文件
2019/10/31 Javascript
Vue 实现登录界面验证码功能
2020/01/03 Javascript
Python设置Socket代理及实现远程摄像头控制的例子
2015/11/13 Python
Python+Selenium自动化实现分页(pagination)处理
2017/03/31 Python
教大家玩转Python字符串处理的七种技巧
2017/03/31 Python
Python命令行解析模块详解
2018/02/01 Python
python 获取指定文件夹下所有文件名称并写入列表的实例
2018/04/23 Python
python在文本开头插入一行的实例
2018/05/02 Python
Python continue语句实例用法
2020/02/06 Python
浅谈Python中threading join和setDaemon用法及区别说明
2020/05/02 Python
CSS 说明横向进度条最后显示文字的实现代码
2020/11/10 HTML / CSS
web字体加载方案优化小结
2019/11/29 HTML / CSS
俄罗斯在线手表和珠宝商店:AllTime
2019/09/28 全球购物
澳大利亚排名第一的露营和户外设备在线零售商:Outbax
2020/05/06 全球购物
工程师岗位职责规定
2014/02/26 职场文书
《女娲补天》教学反思
2016/02/20 职场文书
致男子1500米运动员的广播稿
2019/11/08 职场文书