浅谈Tensorflow模型的保存与恢复加载


Posted in Python onApril 26, 2018

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)

首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍Python中的JSON模块
Apr 08 Python
python实现用户答题功能
Jan 17 Python
Python把csv数据写入list和字典类型的变量脚本方法
Jun 15 Python
解决python 自动安装缺少模块的问题
Oct 22 Python
pandas 数据归一化以及行删除例程的方法
Nov 10 Python
Python3获取电脑IP、主机名、Mac地址的方法示例
Apr 11 Python
利用python-pypcap抓取带VLAN标签的数据包方法
Jul 23 Python
解决安装python3.7.4报错Can''t connect to HTTPS URL because the SSL module is not available
Jul 31 Python
Python 实现Numpy中找出array中最大值所对应的行和列
Nov 26 Python
关于Python 解决Python3.9 pandas.read_excel(‘xxx.xlsx‘)报错的问题
Nov 28 Python
Python  Asyncio模块实现的生产消费者模型的方法
Mar 01 Python
python中pd.cut()与pd.qcut()的对比及示例
Jun 16 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 #Python
Python实现的计算器功能示例
Apr 26 #Python
python email smtplib模块发送邮件代码实例
Apr 26 #Python
Python利用正则表达式实现计算器算法思路解析
Apr 25 #Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 #Python
Python实现按中文排序的方法示例
Apr 25 #Python
Python实现的基于优先等级分配糖果问题算法示例
Apr 25 #Python
You might like
PHP调用三种数据库的方法(2)
2006/10/09 PHP
php调用google接口生成二维码示例
2014/04/28 PHP
PHP常用的缓存技术汇总
2014/05/05 PHP
PHP会话控制:Session与Cookie详解
2014/09/27 PHP
Laravel 5框架学习之表单
2015/04/08 PHP
php操作redis缓存方法分享
2015/06/03 PHP
PHP统计目录中文件以及目录中目录大小的方法
2016/01/09 PHP
PHP简单实现遍历目录下特定文件的方法小结
2017/05/22 PHP
JavaScript日历实现代码
2010/09/12 Javascript
lyhucSelect基于Jquery的Select数据联动插件
2011/03/29 Javascript
JS获取各种浏览器窗口大小的方法
2014/01/14 Javascript
Jquery的基本对象转换和文档加载用法实例
2015/02/25 Javascript
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
nodejs中使用多线程编程的方法实例
2015/03/24 NodeJs
jQuery实现仿腾讯视频列表分页效果的方法
2015/08/07 Javascript
JavaScript实现仿淘宝商品购买数量的增减效果
2016/01/22 Javascript
jQuery验证插件validate使用详解
2016/05/11 Javascript
js实现简单的手风琴效果
2017/02/27 Javascript
各种选择框jQuery的选中方法(实例讲解)
2017/06/27 jQuery
vue实现树形菜单效果
2018/03/19 Javascript
JS秒杀倒计时功能完整实例【使用jQuery3.1.1】
2019/09/03 jQuery
JS如何实现动态添加的元素绑定事件
2019/11/12 Javascript
jQuery使用hide()、toggle()函数实现相机品牌展示隐藏功能
2021/01/29 jQuery
Python struct.unpack
2008/09/06 Python
python实现mysql的单引号字符串过滤方法
2015/11/14 Python
Python实现mysql数据库更新表数据接口的功能
2017/11/19 Python
pycharm 实现显示project 选项卡的方法
2019/01/17 Python
Python socket非阻塞模块应用示例
2019/09/12 Python
css3打造一款漂亮的卡哇伊按钮
2013/03/20 HTML / CSS
Boston Proper官网:美国女装品牌
2017/10/30 全球购物
韩国最大的购物网站:Gmarket
2019/06/20 全球购物
上学迟到的检讨书
2014/01/11 职场文书
项目合作意向书范本
2014/04/01 职场文书
毕业实习证明范本
2015/06/16 职场文书
大学迎新生欢迎词
2015/09/29 职场文书
Java Spring Lifecycle的使用
2022/05/06 Java/Android