浅谈Tensorflow模型的保存与恢复加载


Posted in Python onApril 26, 2018

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)

首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python学习笔记(二)基础语法
Jun 06 Python
Python基于time模块求程序运行时间的方法
Sep 18 Python
对numpy 数组和矩阵的乘法的进一步理解
Apr 04 Python
解决Python pandas plot输出图形中显示中文乱码问题
Dec 12 Python
Python3.7 新特性之dataclass装饰器
May 27 Python
实例详解Python模块decimal
Jun 26 Python
pygame实现俄罗斯方块游戏(AI篇1)
Oct 29 Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 Python
python如何支持并发方法详解
Jul 25 Python
python如何操作mysql
Aug 17 Python
基于python实现百度语音识别和图灵对话
Nov 02 Python
Opencv中cv2.floodFill算法的使用
Jun 18 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 #Python
Python实现的计算器功能示例
Apr 26 #Python
python email smtplib模块发送邮件代码实例
Apr 26 #Python
Python利用正则表达式实现计算器算法思路解析
Apr 25 #Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 #Python
Python实现按中文排序的方法示例
Apr 25 #Python
Python实现的基于优先等级分配糖果问题算法示例
Apr 25 #Python
You might like
浅析ThinkPHP中的pathinfo模式和URL重写
2014/01/06 PHP
php中获取主机名、协议及IP地址的方法
2014/11/18 PHP
在WordPress中获取数据库字段内容和添加主题设置菜单
2016/01/11 PHP
java微信开发之上传下载多媒体文件
2016/06/24 PHP
php简单构造json多维数组的方法示例
2017/06/08 PHP
Javascript 定时器调用传递参数的方法
2009/11/12 Javascript
NodeJs基本语法和类型
2015/02/13 NodeJs
百度UEditor编辑器如何关闭抓取远程图片功能
2015/03/03 Javascript
JavaScript实现的简单拖拽效果
2015/06/01 Javascript
高效Web开发的10个jQuery代码片段
2016/07/22 Javascript
JS实现焦点图轮播效果的方法详解
2016/12/19 Javascript
谈谈为什么你的 JavaScript 代码如此冗长
2019/01/30 Javascript
详解Vue之父子组件传值
2019/04/01 Javascript
vue-mugen-scroll组件实现pc端滚动刷新
2019/08/16 Javascript
聊聊鉴权那些事(推荐)
2019/08/22 Javascript
node.js使用mongoose操作数据库实现购物车的增、删、改、查功能示例
2019/12/23 Javascript
JavaScript适配器模式原理与用法实例详解
2020/03/09 Javascript
jenkins自动构建发布vue项目的方法步骤
2021/01/04 Vue.js
[01:29:31]VP VS VG Supermajor小组赛胜者组第二轮 BO3第一场 6.2
2018/06/03 DOTA
[01:35:53]完美世界DOTA2联赛PWL S3 Magma vs GXR 第二场 12.13
2020/12/17 DOTA
Python装饰器使用示例及实际应用例子
2015/03/06 Python
Python实现自定义函数的5种常见形式分析
2018/06/16 Python
利用python循环创建多个文件的方法
2018/10/25 Python
python实现12306登录并保存cookie的方法示例
2019/12/17 Python
手把手教你安装Windows版本的Tensorflow
2020/03/26 Python
Pymysql实现往表中插入数据过程解析
2020/06/02 Python
Madewell美德威尔美国官网:美国休闲服饰品牌
2016/11/25 全球购物
巴西Mr. Cat在线商店:购买包包和鞋子
2019/09/08 全球购物
阿迪达斯中国官网:Adidas中国
2020/12/14 全球购物
群众路线党员个人剖析材料
2014/10/08 职场文书
乐山大佛导游词
2015/02/02 职场文书
承诺保证书格式
2015/02/28 职场文书
PyCharm配置KBEngine快速处理代码提示冲突、配置命令问题
2021/04/03 Python
JAVA API 实用类 String详解
2021/10/05 Java/Android
CentOS安装Nginx并部署vue
2022/04/12 Servers
Sql Server 行数据的某列值想作为字段列显示的方法
2022/04/20 SQL Server