tensorflow的ckpt及pb模型持久化方式及转化详解


Posted in Python onFebruary 12, 2020

使用tensorflow训练模型的时候,模型持久化对我们来说非常重要。

如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很长。如果在训练过程中出现了模型不可预期的错误,导致训练意外终止,那么我们将会前功尽弃。为了解决这一问题,我们可以使用模型持久化(保存为ckpt文件格式)来保存我们在训练过程中的临时数据。、

如果我们训练出的模型需要提供给用户做离线预测,那么我们只需要完成前向传播过程。这个时候我们就可以使用模型持久化(保存为pb文件格式)来只保存前向传播过程中的变量并将变量固定下来,这时候用户只需要提供一个输入即可得到前向传播的预测结果。

ckpt和pb持久化方式的区别在于ckpt文件将模型结构与模型权重分离保存,便于训练过程;pb文件则是graph_def的序列化文件,便于发布和离线预测。官方提供freeze_grpah.py脚本来将ckpt文件转为pb文件。

CKPT模型持久化

首先定义前向传播过程;

声明并得到一个Saver;

使用Saver.save()保存模型;

# coding=UTF-8 支持中文编码格式
import tensorflow as tf
import shutil
import os.path
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptLoad/Models/"
MODEL_NAME = "model.ckpt"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
init = tf.global_variables_initializer()
saver = tf.train.Saver() #声明saver用于保存模型
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #输入一个数据测试一下
 saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存
 print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #得到当前图有几个操作节点

predictions : [ 101.]
28 ops in the final graph.

注:代码含义请参考注释,需要注意的是可以自定义模型保存的路径

ckpt模型持久化使用起来非常简单,只需要我们声明一个tf.train.Saver,然后调用save()函数,将会话模型保存到指定的目录。执行代码结果,会在我们指定模型目录下出现4个文件

tensorflow的ckpt及pb模型持久化方式及转化详解

checkpoint : 记录目录下所有模型文件列表
ckpt.data : 保存模型中每个变量的取值
ckpt.meta : 保存整个计算图的结构

ckpt模型加载

# -*- coding: utf-8 -*-)
import tensorflow as tf
from numpy.random import RandomState
 
# 定义训练数据batch的大小
batch_size = 8
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
#saver=tf.train.Saver()
# creare a session,创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
 
 saver = tf.train.import_meta_graph('/home/zheng/Models/model/model.meta')
 saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 #saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 # 初始化变量
 sess.run(tf.global_variables_initializer())
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})

代码结果,可以看到运行结果一样

predictions : [ 101.]

PB模型持久化

定义运算过程

通过 get_default_graph().as_graph_def() 得到当前图的计算节点信息

通过 graph_util.convert_variables_to_constants 将相关节点的values固定

通过 tf.gfile.GFile 进行模型持久化

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/pbLoad/Models/"
MODEL_NAME = "model"
 
 
#output_graph = "model/pb/add_model.pb"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions")
init = tf.global_variables_initializer()
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
 graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,
              #通过这个部分就可以完成重输入层到
              #输出层的计算过程
 
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess,
  graph_def,
  ["predictions"] #需要保存节点的名字
 )
 with tf.gfile.GFile(os.path.join(MODEL_DIR,MODEL_NAME), "wb") as f: # 保存模型
  f.write(output_graph_def.SerializeToString()) # 序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node))
 print (predictions)
 
# for op in tf.get_default_graph().get_operations(): 打印模型节点信息
#  print (op.name)

结果输出

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.
Tensor("predictions:0", shape=(1,), dtype=float32)

tensorflow的ckpt及pb模型持久化方式及转化详解

并在指定目录下生成pb文件模型,保存了从输入层到输出层这个计算过程的计算图和相关变量的值,我们得到这个模型后传入一个输入,既可以得到一个预估的输出值

pb模型文件加载

# -*- coding: utf-8 -*-)
from tensorflow.python.platform import gfile
import tensorflow as tf
from numpy.random import RandomState
 
sess = tf.Session()
with gfile.FastGFile('./Models/model', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
sess.run('W1:0')
sess.run('B1:0')
# 输入
input_x = sess.graph.get_tensor_by_name('input_holder:0')
#input_y = sess.graph.get_tensor_by_name('y-input:0')
op = sess.graph.get_tensor_by_name('predictions:0')
ret = sess.run(op, feed_dict={input_x:[10]})
print(ret)

输出结果

[ 101.]

我们可以看到结果一致。

ckpt格式转pb格式

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptToPb/model/"
MODEL_NAME = "frozen_model"
 
def freeze_graph(model_folder):
 checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径
 
 output_node_names = "predictions" #原模型输出操作节点的名字
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
 
 graph = tf.get_default_graph() #获得默认的图
 input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
 
  print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
 
  output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
   sess,
   input_graph_def,
   output_node_names.split(",") #如果有多个输出节点,以逗号隔开
  )
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 
if __name__ == '__main__':
 #parser = argparse.ArgumentParser()
 #parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型,
 # 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
 #aggs = parser.parse_args()
 #freeze_graph(aggs.model_folder)
 freeze_graph("/home/zheng/PycharmProjects/ckptLoad/Models/") #模型目录

注意改变ckpt模型目录及pb文件保存目录 。

tensorflow的ckpt及pb模型持久化方式及转化详解

运行结果为

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.

总结:cpkt文件格式将模型保存为4个文件,pb文件格式为一个。ckpt模型持久化方式将图结构与权重参数分开保存,多了模型更多的细节,适合模型训练阶段;而pb持久化方式完成了从输入到输出的前向传播,完成了端到端的形式,更是个离线使用。

以上这篇tensorflow的ckpt及pb模型持久化方式及转化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python按行读取文件的简单实现方法
Jun 22 Python
Python实现抢购IPhone手机
Feb 07 Python
Python八大常见排序算法定义、实现及时间消耗效率分析
Apr 27 Python
python3模块smtplib实现发送邮件功能
May 22 Python
解决Python安装后pip不能用的问题
Jun 12 Python
浅析python3中的os.path.dirname(__file__)的使用
Aug 30 Python
Python3.4解释器用法简单示例
Mar 22 Python
python实现基于朴素贝叶斯的垃圾分类算法
Jul 09 Python
django连接mysql数据库及建表操作实例详解
Dec 10 Python
Python assert关键字原理及实例解析
Dec 13 Python
Python实现鼠标自动在屏幕上随机移动功能
Mar 14 Python
在django中查询获取数据,get, filter,all(),values()操作
Aug 09 Python
关于Tensorflow 模型持久化详解
Feb 12 #Python
Python qrcode 生成一个二维码的实例详解
Feb 12 #Python
python标准库sys和OS的函数使用方法与实例详解
Feb 12 #Python
python标准库os库的函数介绍
Feb 12 #Python
Tensorflow 1.0之后模型文件、权重数值的读取方式
Feb 12 #Python
Python django框架开发发布会签到系统(web开发)
Feb 12 #Python
Python计算公交发车时间的完整代码
Feb 12 #Python
You might like
WinXP + Apache +PHP5 + MySQL + phpMyAdmin安装全功略
2006/07/09 PHP
php代码收集表单内容并写入文件的代码
2012/01/29 PHP
PHP中如何防止外部恶意提交调用ajax接口
2016/04/11 PHP
PHP中header函数的用法及其注意事项详解
2016/06/13 PHP
php实现表单提交上传文件功能
2018/05/28 PHP
php微信公众号开发之图片回复
2018/10/20 PHP
js 实现打印网页中定义的部分内容的代码
2010/04/01 Javascript
读jQuery之十 事件模块概述
2011/06/27 Javascript
用jquery实现点击栏目背景色改变
2012/12/10 Javascript
JS控制文本框textarea输入字数限制的方法
2013/06/17 Javascript
jQuery焦点图切换简易插件制作过程全纪录
2014/08/27 Javascript
js判断浏览器是否支持严格模式的方法
2016/10/04 Javascript
js继承实现方法详解
2016/12/16 Javascript
javascript字体颜色控件的开发 JS实现字体控制
2017/11/27 Javascript
利用Node.js检测端口是否被占用的方法
2017/12/07 Javascript
全新打包工具parcel零配置vue开发脚手架
2018/01/11 Javascript
jQuery中图片展示插件highslide.js的简单dom
2018/04/22 jQuery
angularjs结合html5实现拖拽功能
2018/06/25 Javascript
vue+vuex+json-seiver实现数据展示+分页功能
2019/04/11 Javascript
[02:40]DOTA2超级联赛专访430 从小就爱玩对抗性游戏
2013/06/18 DOTA
Python使用PyGreSQL操作PostgreSQL数据库教程
2014/07/30 Python
Python读写配置文件的方法
2015/06/03 Python
Python3随机漫步生成数据并绘制
2018/08/27 Python
利用nohup来开启python文件的方法
2019/01/14 Python
python sqlite的Row对象操作示例
2019/09/11 Python
python字符串判断密码强弱
2020/03/18 Python
python打包多类型文件的操作方法
2020/09/21 Python
为什么使用接口?
2014/08/13 面试题
写一个在SQL Server创建表的SQL语句
2012/03/10 面试题
信息系统专业个人求职信范文
2013/12/07 职场文书
大学毕业生求职自荐信
2014/02/20 职场文书
节水标语大全
2014/06/11 职场文书
学雷锋团日活动总结
2015/05/06 职场文书
导游词之五台山
2019/10/11 职场文书
JavaScript实现音乐播放器
2022/08/14 Javascript
CSS 实现磨砂玻璃(毛玻璃)效果样式
2023/05/21 HTML / CSS