tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)


Posted in Python onApril 22, 2020

网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种:

  • 使用tensorflow.python.tools.freeze_graph.freeze_graph
  • 使用graph_util.convert_variables_to_constants

1、tensorflow模型的文件解读

使用tensorflow训练好的模型会自动保存为四个文件,如下

tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)

checkpoint:记录近几次训练好的模型结果(名称)。

xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。

xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。

xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。

2、最常见的ckpt转pb文件的方法

2、ckpt转pb文件(freeze_graph.freeze_graph)

此种方法尝试成功,虽然不知道输出节点名,但是只要模型代码还在就可以操作,直接上代码。

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network # network是你们自己定义的模型结构(代码结构)
# egs:
# def network(input):
# return tf.layers.softmax(input)
 
model_path = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
 
def main():
 tf.reset_default_graph()
 # 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改
 input_node = tf.placeholder(tf.float32, shape=(None, None, 200)) 
 output_node = network(input_node) # 神经网络的输出
 # 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测  精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可)
 flow = tf.cast(output_node , tf.float16, 'the_outputs') 
 saver = tf.train.Saver()
 with tf.Session() as sess:
 saver.restore(sess, model_path)
 #保存模型图(结构),为一个json文件
 tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
 #将模型参数与模型图结合,并保存为pb文件
 freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
 print("done")
if __name__ == '__main__':
 main()

2、ckpt转pb文件(graph_util.convert_variables_to_constants)

没有成功,因为不知道输出节点的名字,使用该方法保存后的pb文件只有几十k,无法使用,写在这里主要是为了总结。直接上代码,代码里面没有的库(函数),按提示自行import。

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=input_graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)

参考链接:

到此这篇关于tensorflow模型文件(ckpt)转pb文件(不知道输出节点名)的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python使用win32com在百度空间插入html元素示例
Feb 20 Python
python3+PyQt5使用数据库表视图
Apr 24 Python
python文件拆分与重组实例
Dec 10 Python
Python单元测试unittest的具体使用示例
Dec 17 Python
详解python实现小波变换的一个简单例子
Jul 18 Python
python实现大战外星人小游戏实例代码
Dec 26 Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 Python
django 扩展user用户字段inlines方式
Mar 30 Python
python读写数据读写csv文件(pandas用法)
Dec 14 Python
Python编程super应用场景及示例解析
Oct 05 Python
Python使用pandas导入xlsx格式的excel文件内容操作代码
Dec 24 Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
matlab中二维插值函数interp2的使用详解
Apr 22 #Python
python 一维二维插值实例
Apr 22 #Python
Numpy一维线性插值函数的用法
Apr 22 #Python
You might like
转生史莱姆:萌王第一次撸串开心到飞起,哥布塔撸串却神似界王神
2018/11/30 日漫
php实现阿拉伯数字和罗马数字相互转换的方法
2015/04/17 PHP
php给图片加文字水印
2015/07/31 PHP
从sohu弄下来的flash中展示图片的代码
2007/04/27 Javascript
jQuery操作select的实例代码
2012/06/14 Javascript
js获取单选框或复选框值及操作
2012/12/18 Javascript
JQuery 获取json数据$.getJSON方法的实例代码
2013/08/02 Javascript
调用DOM对象的focus使文本框获得焦点
2014/02/19 Javascript
js如何实现淡入淡出效果
2020/11/18 Javascript
详解JavaScript中localStorage使用要点
2016/01/13 Javascript
使用JavaScript实现ajax的实例代码
2016/05/11 Javascript
jQuery插件zTree实现清空选中第一个节点所有子节点的方法
2017/03/08 Javascript
微信小程序实现图片轮播及文件上传
2017/04/07 Javascript
jQuery+PHP+Mysql实现抽奖程序
2020/04/12 jQuery
Vue组件化通讯的实例代码
2017/06/23 Javascript
详解RequireJs官方使用教程
2017/10/31 Javascript
动手写一个angular版本的Message组件的方法
2017/12/16 Javascript
微信小程序使用form表单获取输入框数据的实例代码
2018/05/17 Javascript
jQuery无冲突模式详解
2019/01/17 jQuery
vue实现行列转换的一种方法
2019/08/06 Javascript
Vue实现点击当前元素以外的地方隐藏当前元素(实现思路)
2019/12/04 Javascript
解决vue+router路由跳转不起作用的一项原因
2020/07/19 Javascript
利用Python中SocketServer 实现客户端与服务器间非阻塞通信
2016/12/15 Python
pycharm修改界面主题颜色的方法
2019/01/17 Python
python判断文件是否存在,不存在就创建一个的实例
2019/02/18 Python
python实现在函数图像上添加文字和标注的方法
2019/07/08 Python
python定义类self用法实例解析
2020/01/22 Python
Python文字截图识别OCR工具实例解析
2020/03/05 Python
canvas之自定义头像功能实现代码示例
2017/09/29 HTML / CSS
HashMap和Hashtable的区别
2013/05/18 面试题
校园招聘策划书
2014/01/09 职场文书
公司副总经理岗位职责
2014/10/01 职场文书
科学发展观标语
2014/10/08 职场文书
婚礼父母答谢词
2015/01/04 职场文书
超市采购员岗位职责
2015/04/07 职场文书
《秋天的怀念》教学反思
2016/02/17 职场文书