TensorFlow:将ckpt文件固化成pb文件教程


Posted in Python onFebruary 11, 2020

本文是将yolo3目标检测框架训练出来的ckpt文件固化成pb文件,主要利用了GitHub上的该项目。

为什么要最终生成pb文件呢?简单来说就是直接通过tf.saver保存行程的ckpt文件其变量数据和图是分开的。我们知道TensorFlow是先画图,然后通过placeholde往图里面喂数据。这种解耦形式存在的方法对以后的迁移学习以及对程序进行微小的改动提供了极大的便利性。但是对于训练好,以后不再改变的话这种存在就不再需要。一方面,ckpt文件储存的数据都是变量,既然我们不再改动,就应当让其变成常量,直接‘烧'到图里面。另一方面,对于线上的模型,我们一般是通过C++或者C语言编写的程序进行调用。所以一般模型最终形式都是应该写成pb文件的形式。

由于这次的程序直接从GitHub上下载后改动较小就能够运行,也就是自己写了很少一部分程序。因此进行调试的时候还出现了以前根本没有注意的一些小问题,同时发现自己对TensorFlow还需要更加详细的去研读。

首先对程序进行保存的时候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)对训练的数据进行保存,保存格式为ckpt。但是在恢复的时候一直提示有问题,(其恢复语句为:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夹路径)。出现问题的原因我估计是因为我是按照每50个epoch进行保存,而不是让其进行固定次数的batch进行保存,这种固定batch次数的保存系统会自动保存最近5次的ckpt文件(该方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')进行回复)。那么如何将利用epoch的次数进行保存呢(这种保存不是近5次的保存,而是每进行一次保存就会留下当时保存的ckpt,而那种按照batch的会在第n次保存,会将n-5次的删除,n>5)。

我们可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),获取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)进行恢复。当然为了安全起见,应该对ckpt和ckpt.checkpoint_path进行判断是否存在后,再进行恢复语句的调用,建议打开ckptpoint看一下,里面记录的最近五次的model的路径,一目了然。即:

saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

对于固化网络,网上有很多的介绍。之所以再介绍,还是由于是用了别人的网络而不是自己的网络遇到的坑。在固化时候我们需要知道输出tensor的名字,而再恢复的时候我们需要知道placeholder的名字。但是,如果网络复杂或者别人的网络命名比较复杂,或者name=,根本就没有自己命名而用的系统自定义的,这样捋起来还是比较费劲的。当时在网上查找的一些方法,像打印整个网络变量的方法(先不管输出的网路名称,甚至随便起一个名字,先固化好pb文件,然后对pb文件进行读取,最后打印操作的名字:

graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
 
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['cls_score/cls_score', 'cls_prob'] # We split on comma for convenience
  )
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print ('开始打印节点名字')
  for op in graph.get_operations():
    print(op.name)
  print("%d ops in the final graph." % len(output_graph_def.node))

代码一

这样尽然也能打印出来(尽管输出名字是随便命名的)。但是打印出来的是所有的节点的名字,简直不要太多。这样找的话,一方面可能找不对,另一方面也太费事。

那么怎么办?答案简单的让我也很无语。其实,对ckpt进行数据恢复的时候,直接打印输出的tensor名字就可以。比如说在saver以及placeholder定义的时候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我们在后面跟一句:print output,从打印出来的信息即可查看。placeholder的查看方法同样如此。

对网络进行固化:

代码:

input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))
  input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)
  predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
  boxes, scores, classes = predictor.predict(input_image, input_image_shape)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
 
  # 采用meta 结构加载,不需要知道网络结构
  # saver = tf.train.import_meta_graph(model_path, clear_devices=True) 
  # 这里的model_path是model.ckpt.meta文件的全路径
  # ckpt_model_path 是保存模型的文件夹路径
  # saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path))
 
  graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['concat_11','concat_12','concat_13'] # We split on comma for convenience
  )
  # # Finally we serialize and dump the output graph to the filesystem
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())

由于固化的时候是需要先恢复ckpt网络的,所以还是在restore前写了placeholder和输出tensor的定义(需要注点意的是,我们保存的ckpt文件是训练阶段的graph和变量等,其inference输出和最终predict的输出的Tensor不一样,因此predict与inference的输出相比,还包括了一些后处理,比如说nms等等,只有这些后处理也是TensorFlow框架内的方法写的,才能使最终形成的pb文件能够做到输入一张图片,直接输出最终结果。因此,对于目标检测任务,把后处理任务也交由TensorFlow内的api来实现,可免去夸平台读取pb文件后仍然需要重新进行后处理等相关程序的编写带来的不必要麻烦)。然后结合保存变量的那个文件(ckpt),将变量恢复到inference过程所需的变量数据(predict包括inference和eval两个过程,训练过程只有inference和loss过程参与,而预测过程多了一个后处理eval过程,eval过程无变量。这样在生成pb文件的时候也把后处理eval固化进去。喂给网络数据,即可得到输出tensor。

由于有读者在此问到了还是没有弄明白'concat_11','concat_12','concat_13'是如何得来的,我在这里就在详细说一下:

是这样的,在我们恢复网络的时候肯定需要知道saver这个对象的,在这里介绍两种方法生成这个对象的方法。

一:

saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

其中meta_graph_location就是保存模型时的.meta文件的路径。保存后有四个文件(checkpoint、.index、.data-00000-of-00001和.meta文件)。.meta文件就是整个TensorFlow的结构图。

二:

saver = tf.train.Saver()

本文采用的是第二种方法(上面已经有详细的代码),由于这种方法得到的saver对象,他不知道具体图是什么样的,因此在恢复前我有用如下代码

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

把整个结构又加载了一遍。如果采用第一种方法,是不需要在重写这两行代码的。

我们要的就是 boxes, scores, classes这三个tensor的结果,并且想知道他们三个tensor的名字。你直接利用print(boxes, scores, classes)打印出来这三个tensor就会出来这三个tensor具体信息(包括名字,和shape,dtype等)。这个只是利用第二种方法得到saver对象,然后恢复ckpt文件,不涉及到固化pb文件问题。固化pb文件是需要知道这三个tensor的名字,所以需要打印看一下。

如果说,我只拿到了保存后的四个文件(checkpoint、.index、.data-00000-of-00001和.meta文件),其相应用代码写成的结构图不清楚,比如说利用这两行代码:

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

画出的结构图是什么样的,我不知道。那么,想要知道具体的placehold和输出tensor的名字,那只能通过代码一中,打印出所有的OP操作节点,然后进行人工遍历了。

读取pb文件:

代码:

def pb_detect(image_path, pb_model_path):
 
  os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index
  image = Image.open(image_path)
  resize_image = letterbox_image(image, (416, 416))
  image_data = np.array(resize_image, dtype = np.float32)
  image_data /= 255.
  image_data = np.expand_dims(image_data, axis = 0)
  with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(pb_model_path, "rb") as f:
      output_graph_def.ParseFromString(f.read())
      tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0")
      input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0")
      # 定义输出的张量名称
      #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
      boxes = sess.graph.get_tensor_by_name("concat_11:0")
      scores = sess.graph.get_tensor_by_name("concat_12:0")
      classes = sess.graph.get_tensor_by_name("concat_13:0")
      # 读取测试图片
      # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字(需要在名字后面加:0),不是操作节点的名字
      out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes],
              feed_dict={
                input_image_tensor: image_data,
                input_image_tensor_shape: [image.size[1], image.size[0]]
      })

可以看到读取pb文件只需要比恢复ckpt文件容易的多,直接将placeholder的名字获取到,将数据输入恢复的网络,以及读取输出即可。

小记:

有可能是TensorFlow版本更新或者其他原因,在后来工作中加载pb文件是报错了:

ValueError: Fetch argument <tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024) dtype=float32> cannot be interpreted as a Tensor. (tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024), dtype=float32) is not an element of this graph.)

将上面读取pb文件的代码with tf.Graph().as_default():改成

global graph
graph = tf.get_default_graph()
with graph.as_default():

以上这篇TensorFlow:将ckpt文件固化成pb文件教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现linux服务器批量修改密码并生成execl
Apr 22 Python
Python实现设置windows桌面壁纸代码分享
Mar 28 Python
Python字符串中查找子串小技巧
Apr 10 Python
python动态网页批量爬取
Feb 14 Python
Python 多核并行计算的示例代码
Nov 07 Python
python多进程使用及线程池的使用方法代码详解
Oct 24 Python
程序员写Python时的5个坏习惯,你有几条?
Nov 26 Python
python占位符输入方式实例
May 27 Python
python 批量添加的button 使用同一点击事件的方法
Jul 17 Python
使用Python制作表情包实现换脸功能
Jul 19 Python
Python实现打包成库供别的模块调用
Jul 13 Python
详解Python中string模块除去Str还剩下什么
Nov 30 Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
Python版中国省市经纬度
Feb 11 #Python
You might like
phpMyAdmin 链接表的附加功能尚未激活的问题
2010/08/01 PHP
PHP计数器的实现代码
2013/06/08 PHP
学习php设计模式 php实现模板方法模式
2015/12/08 PHP
PHP实现表单提交数据的验证处理功能【防SQL注入和XSS攻击等】
2017/07/21 PHP
thinkphp5.1 框架导入/导出excel文件操作示例
2020/05/25 PHP
详解阿里云视频直播PHP-SDK接入教程
2020/07/09 PHP
JQuery对id中含有特殊字符的转义处理示例
2013/09/06 Javascript
使用js的replace()方法查找字符示例代码
2013/10/28 Javascript
动态加载jquery库的方法
2014/02/12 Javascript
JQuery boxy插件在IE中边角图片不显示问题的解决
2015/05/20 Javascript
js实现YouKu的漂亮搜索框效果
2015/08/19 Javascript
体验jQuery和AngularJS的不同点及AngularJS的迷人之处
2016/02/02 Javascript
jquery插件uploadify多图上传功能实现代码
2016/08/12 Javascript
深入探究angular2 UI组件之primeNG用法
2017/07/26 Javascript
在knockoutjs 上自己实现的flux(实例讲解)
2017/12/18 Javascript
Vue动态组件与异步组件实例详解
2019/02/23 Javascript
JS实现的检验身份证格式并输出出生日期,年龄,性别,出生地示例
2019/05/17 Javascript
vue如何搭建多页面多系统应用
2020/06/17 Javascript
解决vue安装less报错Failed to compile with 1 errors的问题
2020/10/22 Javascript
vue 判断两个时间插件结束时间必选大于开始时间的代码
2020/11/04 Javascript
vue项目实现减少app.js和vender.js的体积操作
2020/11/12 Javascript
[06:23]2014DOTA2西雅图国际邀请赛 小组赛7月12日TOPPLAY
2014/07/12 DOTA
最基础的Python的socket编程入门教程
2015/04/23 Python
python不换行之end=与逗号的意思及用途
2017/11/21 Python
Django中url的反向查询的方法
2018/03/14 Python
python利用跳板机ssh远程连接redis的方法
2019/02/19 Python
如何使用Python实现斐波那契数列
2019/07/02 Python
python开发之anaconda以及win7下安装gensim的方法
2019/07/05 Python
python-docx文件定位读取过程(尝试替换)
2020/02/13 Python
基于python实现地址和经纬度转换
2020/05/19 Python
浅析Python面向对象编程
2020/07/10 Python
html5 Canvas画图教程(11)—使用lineTo/arc/bezierCurveTo画椭圆形
2013/01/09 HTML / CSS
Uber Eats台湾:寻找附近提供送餐服务的餐厅
2018/05/07 全球购物
北美最大的手工艺品零售商之一:Michaels Stores
2019/02/27 全球购物
收银员岗位职责
2015/02/03 职场文书
导游词之京东大峡谷旅游区
2019/10/29 职场文书