浅谈Tensorflow模型的保存与恢复加载


Posted in Python onApril 26, 2018

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)

首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现上传样本到virustotal并查询扫描信息的方法
Oct 05 Python
在Python的while循环中使用else以及循环嵌套的用法
Oct 14 Python
Python处理JSON数据并生成条形图
Aug 05 Python
利用Python找出序列中出现最多的元素示例代码
Dec 08 Python
pandas数据清洗,排序,索引设置,数据选取方法
May 18 Python
python中类的属性和方法介绍
Nov 27 Python
Python读写文件基础知识点
Jun 10 Python
django+tornado实现实时查看远程日志的方法
Aug 12 Python
Python大数据之网络爬虫的post请求、get请求区别实例分析
Nov 16 Python
浅谈pytorch中的BN层的注意事项
Jun 23 Python
python在一个范围内取随机数的简单实例
Aug 16 Python
Python实现PIL图像处理库绘制国际象棋棋盘
Jul 16 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 #Python
Python实现的计算器功能示例
Apr 26 #Python
python email smtplib模块发送邮件代码实例
Apr 26 #Python
Python利用正则表达式实现计算器算法思路解析
Apr 25 #Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 #Python
Python实现按中文排序的方法示例
Apr 25 #Python
Python实现的基于优先等级分配糖果问题算法示例
Apr 25 #Python
You might like
PHP 进程锁定问题分析研究
2009/11/24 PHP
php学习笔记 数组的常用函数
2011/06/13 PHP
解析php DOMElement 操作xml 文档的实现代码
2013/05/10 PHP
ThinkPHP的RBAC(基于角色权限控制)深入解析
2013/06/17 PHP
discuz免激活同步登入代码修改方法(discuz同步登录)
2013/12/24 PHP
PHP获取文件扩展名的常用方法小结【五种方式】
2018/04/27 PHP
css transform 3D幻灯片特效实现步骤解读
2013/03/27 Javascript
javascript设计模式之解释器模式详解
2014/06/05 Javascript
JS运动框架之分享侧边栏动画实例
2015/03/03 Javascript
jquery+CSS实现的多级竖向展开树形TRee菜单效果
2015/08/24 Javascript
javascript实现checkbox复选框实例代码
2016/01/10 Javascript
基于javascript实现checkbox复选框实例代码
2016/01/28 Javascript
AngularJS入门教程中SQL实例详解
2016/07/27 Javascript
vue-resource 拦截器(interceptor)的使用详解
2017/07/04 Javascript
vuejs选中当前样式active的实例
2018/08/22 Javascript
Vue-Cli 3.0 中配置高德地图的两种方式
2019/06/19 Javascript
原生javascript运动函数的封装示例【匀速、抛物线、多属性的运动等】
2020/02/23 Javascript
通过实例解析chrome如何在mac环境中安装vue-devtools插件
2020/07/10 Javascript
理解Proxy及使用Proxy实现vue数据双向绑定操作
2020/07/18 Javascript
vue treeselect获取当前选中项的label实例
2020/08/31 Javascript
Python自定义函数的创建、调用和函数的参数详解
2014/03/11 Python
python脚本替换指定行实现步骤
2017/07/11 Python
python学生管理系统开发
2019/01/30 Python
pygame实现非图片按钮效果
2019/10/29 Python
Python读取表格类型文件代码实例
2020/02/17 Python
html5小程序飞入购物车(抛物线绘制运动轨迹点)
2020/10/19 HTML / CSS
社团活动策划书范文
2014/01/09 职场文书
先进集体事迹材料
2014/02/17 职场文书
工程力学专业自荐信范文
2014/03/17 职场文书
学生干部培训方案
2014/06/12 职场文书
公司演讲稿开场白
2014/08/25 职场文书
创先争优演讲稿
2014/09/15 职场文书
2015年勤工助学工作总结
2015/04/29 职场文书
接触艺术对孩子学习思维有益
2019/08/06 职场文书
Python操作CSV格式文件的方法大全
2021/07/15 Python
奥特曼十大神器:奥特手镯在榜,第一是贝利亚的神器
2022/03/18 日漫