tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)


Posted in Python onApril 22, 2020

网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种:

  • 使用tensorflow.python.tools.freeze_graph.freeze_graph
  • 使用graph_util.convert_variables_to_constants

1、tensorflow模型的文件解读

使用tensorflow训练好的模型会自动保存为四个文件,如下

tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)

checkpoint:记录近几次训练好的模型结果(名称)。

xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。

xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。

xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。

2、最常见的ckpt转pb文件的方法

2、ckpt转pb文件(freeze_graph.freeze_graph)

此种方法尝试成功,虽然不知道输出节点名,但是只要模型代码还在就可以操作,直接上代码。

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network # network是你们自己定义的模型结构(代码结构)
# egs:
# def network(input):
# return tf.layers.softmax(input)
 
model_path = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
 
def main():
 tf.reset_default_graph()
 # 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改
 input_node = tf.placeholder(tf.float32, shape=(None, None, 200)) 
 output_node = network(input_node) # 神经网络的输出
 # 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测  精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可)
 flow = tf.cast(output_node , tf.float16, 'the_outputs') 
 saver = tf.train.Saver()
 with tf.Session() as sess:
 saver.restore(sess, model_path)
 #保存模型图(结构),为一个json文件
 tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
 #将模型参数与模型图结合,并保存为pb文件
 freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
 print("done")
if __name__ == '__main__':
 main()

2、ckpt转pb文件(graph_util.convert_variables_to_constants)

没有成功,因为不知道输出节点的名字,使用该方法保存后的pb文件只有几十k,无法使用,写在这里主要是为了总结。直接上代码,代码里面没有的库(函数),按提示自行import。

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=input_graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)

参考链接:

到此这篇关于tensorflow模型文件(ckpt)转pb文件(不知道输出节点名)的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用Python标准库中的wave模块绘制乐谱的简单教程
Mar 30 Python
wxpython中Textctrl回车事件无效的解决方法
Jul 21 Python
Python 读写文件和file对象的方法(推荐)
Sep 12 Python
python 计算文件的md5值实例
Jan 13 Python
python实现kNN算法
Dec 20 Python
Python2和Python3.6环境解决共存问题
Nov 09 Python
python通过配置文件共享全局变量的实例
Jan 11 Python
详解python-图像处理(映射变换)
Mar 22 Python
对django中foreignkey的简单使用详解
Jul 28 Python
Python closure闭包解释及其注意点详解
Aug 28 Python
Pycharm操作Git及GitHub的步骤详解
Oct 27 Python
python-地图可视化组件folium的操作
Dec 14 Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
matlab中二维插值函数interp2的使用详解
Apr 22 #Python
python 一维二维插值实例
Apr 22 #Python
Numpy一维线性插值函数的用法
Apr 22 #Python
You might like
PHP中PDO基础教程 入门级
2011/09/04 PHP
php数组的概述及分类与声明代码演示
2013/02/26 PHP
PHP识别二维码的方法(php-zbarcode安装与使用)
2016/07/07 PHP
PHP简单实现上一页下一页功能示例
2016/09/14 PHP
javascript中的location用法简单介绍
2007/03/07 Javascript
利用javascript实现一些常用软件的下载导航
2009/08/03 Javascript
ie支持function.bind()方法实现代码
2012/12/27 Javascript
button没写type=button会导致点击时提交
2014/03/06 Javascript
js同源策略详解
2015/05/21 Javascript
jQuery隐藏和显示效果实现
2016/04/06 Javascript
JavaScript获取select中text值的方法
2017/02/13 Javascript
基于Vue中点击组件外关闭组件的实现方法
2018/03/06 Javascript
layui 关闭open弹出框 刷新table表格页面的方法
2019/09/16 Javascript
layer页面跳转,获取html子节点元素的值方法
2019/09/27 Javascript
原生Vue 实现右键菜单组件功能
2019/12/16 Javascript
Vue 修改网站图标的方法
2020/12/31 Vue.js
Python pass 语句使用示例
2014/03/11 Python
pycharm 使用心得(六)进行简单的数据库管理
2014/06/06 Python
Python2.x版本中cmp()方法的使用教程
2015/05/14 Python
python 捕获 shell/bash 脚本的输出结果实例
2017/01/04 Python
详解Python匿名函数(lambda函数)
2019/04/19 Python
Python3.5常见内置方法参数用法实例详解
2019/04/29 Python
Python 一键获取百度网盘提取码的方法
2019/08/01 Python
Python的bit_length函数来二进制的位数方法
2019/08/27 Python
wxPython修改文本框颜色过程解析
2020/02/14 Python
采用专利算法搜索最廉价的机票:CheapAir
2016/09/10 全球购物
电大学习个人自我评价范文
2013/10/04 职场文书
工商学院毕业生自荐信
2013/11/12 职场文书
2014年党务公开实施方案
2014/02/27 职场文书
学校安全责任书范本
2014/07/23 职场文书
酒店总经理岗位职责范本
2014/08/08 职场文书
出差报告格式模板
2014/11/06 职场文书
2015年前台接待工作总结
2015/05/04 职场文书
退伍军人感言
2015/08/01 职场文书
2019大学生社会实践报告汇总
2019/08/16 职场文书
创业计划书之川味火锅店
2019/09/02 职场文书