tensorflow的ckpt及pb模型持久化方式及转化详解


Posted in Python onFebruary 12, 2020

使用tensorflow训练模型的时候,模型持久化对我们来说非常重要。

如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很长。如果在训练过程中出现了模型不可预期的错误,导致训练意外终止,那么我们将会前功尽弃。为了解决这一问题,我们可以使用模型持久化(保存为ckpt文件格式)来保存我们在训练过程中的临时数据。、

如果我们训练出的模型需要提供给用户做离线预测,那么我们只需要完成前向传播过程。这个时候我们就可以使用模型持久化(保存为pb文件格式)来只保存前向传播过程中的变量并将变量固定下来,这时候用户只需要提供一个输入即可得到前向传播的预测结果。

ckpt和pb持久化方式的区别在于ckpt文件将模型结构与模型权重分离保存,便于训练过程;pb文件则是graph_def的序列化文件,便于发布和离线预测。官方提供freeze_grpah.py脚本来将ckpt文件转为pb文件。

CKPT模型持久化

首先定义前向传播过程;

声明并得到一个Saver;

使用Saver.save()保存模型;

# coding=UTF-8 支持中文编码格式
import tensorflow as tf
import shutil
import os.path
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptLoad/Models/"
MODEL_NAME = "model.ckpt"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
init = tf.global_variables_initializer()
saver = tf.train.Saver() #声明saver用于保存模型
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #输入一个数据测试一下
 saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存
 print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #得到当前图有几个操作节点

predictions : [ 101.]
28 ops in the final graph.

注:代码含义请参考注释,需要注意的是可以自定义模型保存的路径

ckpt模型持久化使用起来非常简单,只需要我们声明一个tf.train.Saver,然后调用save()函数,将会话模型保存到指定的目录。执行代码结果,会在我们指定模型目录下出现4个文件

tensorflow的ckpt及pb模型持久化方式及转化详解

checkpoint : 记录目录下所有模型文件列表
ckpt.data : 保存模型中每个变量的取值
ckpt.meta : 保存整个计算图的结构

ckpt模型加载

# -*- coding: utf-8 -*-)
import tensorflow as tf
from numpy.random import RandomState
 
# 定义训练数据batch的大小
batch_size = 8
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
#saver=tf.train.Saver()
# creare a session,创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
 
 saver = tf.train.import_meta_graph('/home/zheng/Models/model/model.meta')
 saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 #saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 # 初始化变量
 sess.run(tf.global_variables_initializer())
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})

代码结果,可以看到运行结果一样

predictions : [ 101.]

PB模型持久化

定义运算过程

通过 get_default_graph().as_graph_def() 得到当前图的计算节点信息

通过 graph_util.convert_variables_to_constants 将相关节点的values固定

通过 tf.gfile.GFile 进行模型持久化

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/pbLoad/Models/"
MODEL_NAME = "model"
 
 
#output_graph = "model/pb/add_model.pb"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions")
init = tf.global_variables_initializer()
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
 graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,
              #通过这个部分就可以完成重输入层到
              #输出层的计算过程
 
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess,
  graph_def,
  ["predictions"] #需要保存节点的名字
 )
 with tf.gfile.GFile(os.path.join(MODEL_DIR,MODEL_NAME), "wb") as f: # 保存模型
  f.write(output_graph_def.SerializeToString()) # 序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node))
 print (predictions)
 
# for op in tf.get_default_graph().get_operations(): 打印模型节点信息
#  print (op.name)

结果输出

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.
Tensor("predictions:0", shape=(1,), dtype=float32)

tensorflow的ckpt及pb模型持久化方式及转化详解

并在指定目录下生成pb文件模型,保存了从输入层到输出层这个计算过程的计算图和相关变量的值,我们得到这个模型后传入一个输入,既可以得到一个预估的输出值

pb模型文件加载

# -*- coding: utf-8 -*-)
from tensorflow.python.platform import gfile
import tensorflow as tf
from numpy.random import RandomState
 
sess = tf.Session()
with gfile.FastGFile('./Models/model', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
sess.run('W1:0')
sess.run('B1:0')
# 输入
input_x = sess.graph.get_tensor_by_name('input_holder:0')
#input_y = sess.graph.get_tensor_by_name('y-input:0')
op = sess.graph.get_tensor_by_name('predictions:0')
ret = sess.run(op, feed_dict={input_x:[10]})
print(ret)

输出结果

[ 101.]

我们可以看到结果一致。

ckpt格式转pb格式

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptToPb/model/"
MODEL_NAME = "frozen_model"
 
def freeze_graph(model_folder):
 checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径
 
 output_node_names = "predictions" #原模型输出操作节点的名字
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
 
 graph = tf.get_default_graph() #获得默认的图
 input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
 
  print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
 
  output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
   sess,
   input_graph_def,
   output_node_names.split(",") #如果有多个输出节点,以逗号隔开
  )
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 
if __name__ == '__main__':
 #parser = argparse.ArgumentParser()
 #parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型,
 # 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
 #aggs = parser.parse_args()
 #freeze_graph(aggs.model_folder)
 freeze_graph("/home/zheng/PycharmProjects/ckptLoad/Models/") #模型目录

注意改变ckpt模型目录及pb文件保存目录 。

tensorflow的ckpt及pb模型持久化方式及转化详解

运行结果为

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.

总结:cpkt文件格式将模型保存为4个文件,pb文件格式为一个。ckpt模型持久化方式将图结构与权重参数分开保存,多了模型更多的细节,适合模型训练阶段;而pb持久化方式完成了从输入到输出的前向传播,完成了端到端的形式,更是个离线使用。

以上这篇tensorflow的ckpt及pb模型持久化方式及转化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3之模块psutil系统性能信息使用
May 30 Python
Python 字符串转换为整形和浮点类型的方法
Jul 17 Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 Python
Python3 串口接收与发送16进制数据包的实例
Jun 12 Python
Python简单实现区域生长方式
Jan 16 Python
用 Python 制作地球仪的方法
Apr 24 Python
基于python连接oracle导并出数据文件
Apr 28 Python
Python logging模块写入中文出现乱码
May 21 Python
五分钟带你搞懂python 迭代器与生成器
Aug 30 Python
python 绘制场景热力图的示例
Sep 23 Python
详解python的super()的作用和原理
Oct 29 Python
python爬取youtube视频的示例代码
Mar 03 Python
关于Tensorflow 模型持久化详解
Feb 12 #Python
Python qrcode 生成一个二维码的实例详解
Feb 12 #Python
python标准库sys和OS的函数使用方法与实例详解
Feb 12 #Python
python标准库os库的函数介绍
Feb 12 #Python
Tensorflow 1.0之后模型文件、权重数值的读取方式
Feb 12 #Python
Python django框架开发发布会签到系统(web开发)
Feb 12 #Python
Python计算公交发车时间的完整代码
Feb 12 #Python
You might like
收音机怀古---春雷3P7图片欣赏
2021/03/02 无线电
利用PHP实现与ASP Banner组件相似的类
2006/10/09 PHP
开发大型 PHP 项目的方法
2007/01/02 PHP
php后台多用户权限组思路与实现程序代码分享
2012/02/13 PHP
PHP如何实现跨域
2016/05/30 PHP
asp.net+jquery滚动滚动条加载数据的下拉控件
2010/06/25 Javascript
JQuery实现当鼠标停留在某区域3秒后自动执行
2014/09/09 Javascript
JS合并数组的几种方法及优劣比较
2014/09/19 Javascript
jQuery给动态添加的元素绑定事件的方法
2015/03/09 Javascript
jQuery scrollFix滚动定位插件
2015/04/01 Javascript
深入分析下javascript中的[]()+!
2015/07/07 Javascript
jquery解析json格式数据的方法(对象、字符串)
2015/11/24 Javascript
Jquery-1.9.1源码分析系列(十一)之DOM操作
2015/11/25 Javascript
Javascript基础_嵌入图像的简单实现
2016/06/14 Javascript
vue.js绑定事件监听器示例【基于v-on事件绑定】
2018/07/07 Javascript
移动端图片上传旋转、压缩问题的方法
2018/10/16 Javascript
ES6入门教程之let、const的使用方法
2019/04/13 Javascript
nodejs中实现用户注册路由功能
2019/05/20 NodeJs
深入了解JavaScript 私有化
2019/05/30 Javascript
js代码实现轮播图
2020/05/04 Javascript
JS禁用右键、禁用Ctrl+u、禁用Ctrl+s、禁用F12的实现代码
2020/12/01 Javascript
[53:21]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS LGD-CDEC
2014/05/22 DOTA
Python编程实现的简单神经网络算法示例
2018/01/26 Python
python对excel文档去重及求和的实例
2018/04/18 Python
PyCharm安装第三方库如Requests的图文教程
2018/05/18 Python
Python 的字典(Dict)是如何存储的
2019/07/05 Python
Python根据服务获取端口号的方法
2019/09/25 Python
pytorch实现用CNN和LSTM对文本进行分类方式
2020/01/08 Python
浅谈TensorFlow中读取图像数据的三种方式
2020/06/30 Python
详解基于Scrapy的IP代理池搭建
2020/09/29 Python
微信浏览器取消缓存的方法
2015/03/28 HTML / CSS
中国综合网上购物商城:苏宁易购
2016/08/09 全球购物
办公室文秘自我评价
2013/09/21 职场文书
自荐信包含哪些内容
2013/10/30 职场文书
最新会计专业求职信范文
2014/01/28 职场文书
小学校长开学致辞
2015/07/29 职场文书