关于Tensorflow 模型持久化详解


Posted in Python onFebruary 12, 2020

当我们使用 tensorflow 训练神经网络的时候,模型持久化对于我们的训练有很重要的作用。

如果我们的神经网络比较复杂,训练数据比较多,那么我们的模型训练就会耗时很长,如果在训练过程中出现某些不可预计的错误,导致我们的训练意外终止,那么我们将会前功尽弃。为了避免这个问题,我们就可以通过模型持久化(保存为CKPT格式)来暂存我们训练过程中的临时数据。

如果我们训练的模型需要提供给用户做离线的预测,那么我们只需要前向传播的过程,只需得到预测值就可以了,这个时候我们就可以通过模型持久化(保存为PB格式)只保存前向传播中需要的变量并将变量的值固定下来,这个时候只需用户提供一个输入,我们就可以通过模型得到一个输出给用户。

保存为 CKPT 格式的模型

定义运算过程

声明并得到一个 Saver

通过 Saver.save 保存模型

# coding=UTF-8 支持中文编码格式
import tensorflow as tf
import shutil
import os.path

MODEL_DIR = "model/ckpt"
MODEL_NAME = "model.ckpt"

# if os.path.exists(MODEL_DIR): 删除目录
#   shutil.rmtree(MODEL_DIR)
if not tf.gfile.Exists(MODEL_DIR): #创建目录
  tf.gfile.MakeDirs(MODEL_DIR)

#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.greater(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false

init = tf.global_variables_initializer()
saver = tf.train.Saver() #声明saver用于保存模型

with tf.Session() as sess:
  sess.run(init)
  print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #输入一个数据测试一下
  saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存
  print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #得到当前图有几个操作节点

for op in tf.get_default_graph().get_operations(): #打印模型节点信息
  print (op.name, op.values())

运行后生成的文件如下:

关于Tensorflow 模型持久化详解

checkpoint : 记录目录下所有模型文件列表
ckpt.data : 保存模型中每个变量的取值
ckpt.meta : 保存整个计算图的结构

保存为 PB 格式模型

定义运算过程
通过 get_default_graph().as_graph_def() 得到当前图的计算节点信息
通过 graph_util.convert_variables_to_constants 将相关节点的values固定
通过 tf.gfile.GFile 进行模型持久化

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util


# MODEL_DIR = "model/pb"
# MODEL_NAME = "addmodel.pb"

# if os.path.exists(MODEL_DIR): 删除目录
#   shutil.rmtree(MODEL_DIR)
#
# if not tf.gfile.Exists(MODEL_DIR): #创建目录
#   tf.gfile.MakeDirs(MODEL_DIR)

output_graph = "model/pb/add_model.pb"

#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否则返回false
predictions = tf.add(_y, 10,name="predictions") #做一个加法运算

init = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init)
  print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
  graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,通过这个部分就可以完成重输入层到输出层的计算过程

  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
    sess,
    graph_def,
    ["predictions"] #需要保存节点的名字
  )
  with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
    f.write(output_graph_def.SerializeToString()) # 序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node))
  print (predictions)

# for op in tf.get_default_graph().get_operations(): 打印模型节点信息
#   print (op.name)

*GraphDef:这个属性记录了tensorflow计算图上节点的信息。

关于Tensorflow 模型持久化详解

add_model.pb : 里面保存了重输入层到输出层这个计算过程的计算图和相关变量的值,我们得到这个模型后传入一个输入,既可以得到一个预估的输出值

CKPT 转换成 PB格式

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util

MODEL_DIR = "model/pb"
MODEL_NAME = "frozen_model.pb"

if not tf.gfile.Exists(MODEL_DIR): #创建目录
  tf.gfile.MakeDirs(MODEL_DIR)

def freeze_graph(model_folder):
  checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
  input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
  output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径

  output_node_names = "predictions" #原模型输出操作节点的名字
  saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.

  graph = tf.get_default_graph() #获得默认的图
  input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图

  with tf.Session() as sess:
    saver.restore(sess, input_checkpoint) #恢复图并得到数据

    print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字

    output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
      sess,
      input_graph_def,
      output_node_names.split(",") #如果有多个输出节点,以逗号隔开
    )
    with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
      f.write(output_graph_def.SerializeToString()) #序列化输出
    print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

    for op in graph.get_operations():
      print(op.name, op.values())

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型,
  # 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
  aggs = parser.parse_args()
  freeze_graph(aggs.model_folder)
  # freeze_graph("model/ckpt") #模型目录

以上这篇关于Tensorflow 模型持久化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python argv用法详解
Jan 08 Python
Python时间的精准正则匹配方法分析
Aug 17 Python
简述Python2与Python3的不同点
Jan 21 Python
python编程使用selenium模拟登陆淘宝实例代码
Jan 25 Python
python3实现windows下同名进程监控
Jun 21 Python
Django中的文件的上传的几种方式
Jul 23 Python
Python3.5面向对象与继承图文实例详解
Apr 24 Python
python调用函数、类和文件操作简单实例总结
Nov 29 Python
Python如何输出警告信息
Jul 30 Python
Python爬虫入门教程01之爬取豆瓣Top电影
Jan 24 Python
Python的代理类实现,控制访问和修改属性的权限你都了解吗
Mar 21 Python
在python中读取和写入CSV文件详情
Jun 28 Python
Python qrcode 生成一个二维码的实例详解
Feb 12 #Python
python标准库sys和OS的函数使用方法与实例详解
Feb 12 #Python
python标准库os库的函数介绍
Feb 12 #Python
Tensorflow 1.0之后模型文件、权重数值的读取方式
Feb 12 #Python
Python django框架开发发布会签到系统(web开发)
Feb 12 #Python
Python计算公交发车时间的完整代码
Feb 12 #Python
详解Django3中直接添加Websockets方式
Feb 12 #Python
You might like
PHP提取中文首字母
2008/04/09 PHP
简单示例AJAX结合PHP代码实现登录效果代码
2008/07/25 PHP
php支付宝接口用法分析
2015/01/04 PHP
php生成毫秒时间戳的实例讲解
2017/09/22 PHP
基于Jquery的动态添加控件并取值的实现代码
2010/09/24 Javascript
javascript 单例/单体模式(Singleton)
2011/04/07 Javascript
JavaScript表达式:URL 协议介绍
2013/03/10 Javascript
JS制作简单的三级联动
2015/03/18 Javascript
JS实现复制内容到剪贴板功能兼容所有浏览器(推荐)
2016/06/17 Javascript
JS实现六边形3D拖拽翻转效果的方法
2016/09/11 Javascript
javascript中call,apply,bind函数用法示例
2016/12/19 Javascript
JavaScript与Java正则表达式写法的区别介绍
2017/08/15 Javascript
JS中使用textPath实现线条上的文字
2017/12/25 Javascript
AngularJS基于http请求实现下载php生成的excel文件功能示例
2018/01/23 Javascript
web3.js增加eth.getRawTransactionByHash(txhash)方法步骤
2018/03/15 Javascript
VUE注册全局组件和局部组件过程解析
2019/10/10 Javascript
使用Vue-cli3.0创建的项目 如何发布npm包
2019/10/10 Javascript
Vue 中如何将函数作为 props 传递给组件的实现代码
2020/05/12 Javascript
elementUI同一页面展示多个Dialog的实现
2020/11/19 Javascript
[06:21]完美世界亚洲区首席发行官竺琦TI3采访
2013/08/26 DOTA
[40:19]2018完美盛典CS.GO表演赛
2018/12/17 DOTA
python中将函数赋值给变量时需要注意的一些问题
2017/08/18 Python
python 对多个csv文件分别进行处理的方法
2019/01/07 Python
使用python快速在局域网内搭建http传输文件服务的方法
2019/11/14 Python
Python3.6安装卸载、执行命令、执行py文件的方法详解
2020/02/20 Python
Python爬虫JSON及JSONPath运行原理详解
2020/06/04 Python
Django自定义YamlField实现过程解析
2020/11/11 Python
viagogo波兰票务平台:演唱会、体育比赛、戏剧门票
2018/04/23 全球购物
日本著名化妆品零售网站:Cosme Land
2019/03/01 全球购物
毕业生找工作推荐信
2013/11/21 职场文书
预备党员综合考察材料
2014/05/31 职场文书
社区护士演讲稿
2014/08/27 职场文书
银行员工犯错检讨书
2014/09/16 职场文书
产品质量保证书范本
2015/02/27 职场文书
看古人们是如何赞美老师的?
2019/07/08 职场文书
MySQL下使用Inplace和Online方式创建索引的教程
2021/05/26 MySQL