tensorflow实现将ckpt转pb文件的方法


Posted in Python onApril 22, 2020

   本博客实现将自己训练保存的ckpt模型转换为pb文件,该方法适用于任何ckpt模型,当然你需要确定ckpt模型输入/输出的节点名称。

   使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

    例如:下面的代码运行后,会在save目录下保存了四个文件:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

    其中,checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值
   但很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

一、CKPT 转换成 PB格式

    将CKPT 转换成 PB格式的文件的过程可简述如下:

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
 下面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())

说明:

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

 2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

4、实质上,我们可以直接在恢复的会话sess中,获得默认的网络图,更简单的方法,如下:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

调用方法很简单,输入ckpt模型路径,输出pb模型的路径即可:

    # 输入ckpt模型路径
    input_checkpoint='models/model.ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="models/pb/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

5、上面以及说明:在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。因此,其他网络模型,也可以通过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

二、 pb模型预测

    下面是预测pb模型的代码

def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))

说明:

1、与ckpt预测不同的是,pb文件已经固化了网络模型结构,因此,即使不知道原训练模型(train)的源码,我们也可以恢复网络图,并进行预测。恢复模型十分简单,只需要从读取的序列化数据中导入网络结构即可:

tf.import_graph_def(output_graph_def, name="")
2、但必须知道原网络模型的输入和输出的节点名称(当然了,传递数据时,是通过输入输出的张量来完成的)。由于InceptionV3模型的输入有三个节点,因此这里需要定义输入的张量名称,它对应网络结构的输入张量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")
input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
以及输出的张量名称:

output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

3、预测时,需要feed输入数据:

# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4、其他网络模型预测时,也可以通过修改输入和输出的张量的名称 。

       PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"InceptionV3/Logits/SpatialSqueeze:0"

完整的CKPT 转换成 PB格式和预测的代码如下:

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))
 
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in sess.graph.get_operations():
 # print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

三、源码下载和资料推荐

    1、训练方法
     上面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:

《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:https://blog.csdn.net/guyuealian/article/details/81560537

    2、Github地址
Github源码:https://github.com/PanJinquan/tensorflow_models_nets  中的convert_pb.py文件

预训练模型下载地址:http://xiazai.3water.com/202004/yuanma/googlenet_inception_3water.rar

    3、将模型移植Android的方法
     pb文件是可以移植到Android平台运行的,其方法,可参考:

《将tensorflow训练好的模型移植到Android (MNIST手写数字识别)》

参考:

[1] https://3water.com/article/185209.htm

【2】https://3water.com/article/185206.htm

到此这篇关于tensorflow实现将ckpt转pb文件的方法的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的localtime()方法使用详解
May 22 Python
python数据批量写入ScrolledText的优化方法
Oct 11 Python
Python中类的创建和实例化操作示例
Feb 27 Python
Django 中间键和上下文处理器的使用
Mar 17 Python
Python3 合并二叉树的实现
Sep 30 Python
Python列表倒序输出及其效率详解
Mar 04 Python
python实现贪吃蛇游戏源码
Mar 21 Python
利用python实现凯撒密码加解密功能
Mar 31 Python
用opencv给图片换背景色的示例代码
Jul 08 Python
关于python3.9安装wordcloud出错的问题及解决办法
Nov 02 Python
python实现图片,视频人脸识别(dlib版)
Nov 18 Python
详解Java中一维、二维数组在内存中的结构
Feb 11 Python
jupyter lab文件导出/下载方式
Apr 22 #Python
python模拟实现分发扑克牌
Apr 22 #Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
Apr 22 #Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
You might like
2020显卡排行榜天梯图 显卡天梯图2020年3月最新版
2020/04/02 数码科技
PHP中用header图片地址 简单隐藏图片源地址
2008/04/09 PHP
检查用户名是否已在mysql中存在的php写法
2014/01/20 PHP
php的慢速日志引起的Mysql错误问题分析
2014/05/13 PHP
根据分辨率不同,调用不同的css文件
2006/08/25 Javascript
jQuery 幻灯片插件(带缩略图功能)
2011/01/24 Javascript
js 浏览器事件介绍
2012/03/30 Javascript
javascript设计模式 接口介绍
2012/07/24 Javascript
Javascript倒计时页面跳转实例小结
2013/09/11 Javascript
jquery限定文本框只能输入数字即整数和小数
2013/11/29 Javascript
零基础搭建Node.js、Express、Ejs、Mongodb服务器及应用开发入门
2014/12/20 Javascript
javascript操作符"!~"详解
2015/02/10 Javascript
JSON格式的时间/Date(2367828670431)/格式转为正常的年-月-日 格式的代码
2016/07/27 Javascript
使用ajaxfileupload.js实现上传文件功能
2016/08/13 Javascript
微信开发 JS-SDK 6.0.2 经常遇到问题总结
2016/12/08 Javascript
JS中如何实现点击a标签返回页面顶部的问题
2017/01/19 Javascript
jQuery Autocomplete简介_动力节点Java学院整理
2017/07/17 jQuery
React Native之ListView实现九宫格效果的示例
2017/08/02 Javascript
Vue组件通信实践记录(推荐)
2017/08/15 Javascript
Angular 4中如何显示内容的CSS样式示例代码
2017/11/06 Javascript
从理论角度讨论JavaScript闭包
2019/04/03 Javascript
新手如何快速理解js异步编程
2019/06/24 Javascript
[01:45]2014DOTA2 TI预选赛预选赛 大神专访第二弹!
2014/05/20 DOTA
[10:21]2018DOTA2国际邀请赛寻真——Winstrike
2018/08/11 DOTA
python使用cookielib库示例分享
2014/03/03 Python
通过Python 获取Android设备信息的轻量级框架
2017/12/18 Python
python判断字符串以什么结尾的实例方法
2020/09/18 Python
如何利用input事件来监听移动端的输入
2016/04/15 HTML / CSS
Linux常见面试题
2013/03/18 面试题
建筑经济管理专业求职信分享
2014/01/06 职场文书
差生评语大全
2014/05/04 职场文书
建国大业观后感
2015/06/01 职场文书
2016幼儿园教师节新闻稿
2015/11/25 职场文书
python实战之90行代码写个猜数字游戏
2021/04/22 Python
Python selenium的这三种等待方式一定要会!
2021/06/10 Python
Win Server2016远程桌面如何允许多用户同时登录
2022/06/10 Servers