浅谈Tensorflow模型的保存与恢复加载


Posted in Python onApril 26, 2018

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)

首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发腾讯微博代码分享
Jan 10 Python
Python入门篇之字典
Oct 17 Python
python getopt详解及简单实例
Dec 30 Python
python操作列表的函数使用代码详解
Dec 28 Python
Python3内置模块pprint让打印比print更美观详解
Jun 02 Python
python打造爬虫代理池过程解析
Aug 15 Python
python数据持久存储 pickle模块的基本使用方法解析
Aug 30 Python
python GUI库图形界面开发之PyQt5滚动条控件QScrollBar详细使用方法与实例
Mar 06 Python
Win10下用Anaconda安装TensorFlow(图文教程)
Jun 18 Python
Python 如何创建一个线程池
Jul 28 Python
Python pip install之SSL异常处理操作
Sep 03 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 #Python
Python实现的计算器功能示例
Apr 26 #Python
python email smtplib模块发送邮件代码实例
Apr 26 #Python
Python利用正则表达式实现计算器算法思路解析
Apr 25 #Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 #Python
Python实现按中文排序的方法示例
Apr 25 #Python
Python实现的基于优先等级分配糖果问题算法示例
Apr 25 #Python
You might like
用php获取远程图片并把它保存到本地的代码
2008/04/07 PHP
PHP实现XML与数据格式进行转换类实例
2015/07/29 PHP
PHP随手笔记整理之PHP脚本和JAVA连接mysql数据库
2015/11/25 PHP
PHP进阶学习之类的自动加载机制原理分析
2019/06/18 PHP
使用jQuery快速解决input中placeholder值在ie中无法支持的问题
2014/01/02 Javascript
Javascript表单特效之十大常用原理性样例代码大总结
2016/07/12 Javascript
JavaScript中对象的不同创建方法
2016/08/12 Javascript
vue.js学习之vue-cli定制脚手架详解
2017/07/02 Javascript
mac上node.js环境的安装测试
2017/07/03 Javascript
关于vue.js发布后路径引用的问题解决
2017/08/15 Javascript
axios全局注册,设置token,以及全局设置url请求网段的方法
2018/09/25 Javascript
小程序实现左右来回滚动字幕效果
2018/12/28 Javascript
关于微信小程序获取小程序码并接受buffer流保存为图片的方法
2019/06/07 Javascript
nuxt配置通过指定IP和端口访问的实现
2020/01/08 Javascript
vue-amap根据地址回显地图并mark的操作
2020/11/03 Javascript
Vue实现点击当前行变色
2020/12/14 Vue.js
[08:08]2014DOTA2国际邀请赛中国区预选赛精彩TOPPLAY
2014/06/25 DOTA
Python采用raw_input读取输入值的方法
2014/08/18 Python
Python多线程编程(一):threading模块综述
2015/04/05 Python
Python读取键盘输入的2种方法
2015/06/16 Python
Python基于分水岭算法解决走迷宫游戏示例
2017/09/26 Python
举例讲解Python常用模块
2019/03/08 Python
基于canvas使用贝塞尔曲线平滑拟合折线段的方法
2018/01/10 HTML / CSS
Square Off美国/加拿大:世界上最聪明的国际象棋棋盘
2018/12/06 全球购物
单位领导证婚词
2014/01/14 职场文书
2013年研究生毕业感言
2014/02/06 职场文书
《燕子专列》教学反思
2014/02/21 职场文书
无私奉献演讲稿
2014/09/04 职场文书
领导班子奢靡之风查摆问题及整改措施
2014/09/27 职场文书
财务整改报告范文
2014/11/05 职场文书
2014年信息中心工作总结
2014/12/17 职场文书
2015年学习部工作总结范文
2015/03/31 职场文书
2015年汽车销售员工作总结
2015/07/24 职场文书
文明礼仪主题班会
2015/08/13 职场文书
2016年幼儿园教师政治学习心得体会
2016/01/23 职场文书
涨工资申请书应该怎么写?
2019/07/08 职场文书