Tensorflow之Saver的用法详解


Posted in Python onApril 23, 2018

Saver的用法

1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。

为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. Saver的实例

下面以一个例子来讲述如何使用Saver类 

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))
  1. isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
  2. train_steps:表示训练的次数,例子中使用100
  3. checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
  4. checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

2.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

Tensorflow之Saver的用法详解

打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

Tensorflow之Saver的用法详解

2.1测试阶段

测试阶段使用saver.restore()方法恢复变量:

sess:表示当前会话,之前保存的结果将被加载入这个会话

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

Tensorflow之Saver的用法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python及PyCharm下载与安装教程
Nov 18 Python
python2.7实现爬虫网页数据
May 25 Python
Python if语句知识点用法总结
Jun 10 Python
python实现自动发送邮件
Jun 20 Python
transform python环境快速配置方法
Sep 27 Python
使用python将时间转换为指定的格式方法
Nov 12 Python
windows下安装Python虚拟环境virtualenvwrapper-win
Jun 14 Python
python编写猜数字小游戏
Oct 06 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
Apr 07 Python
Python爬虫入门有哪些基础知识点
Jun 02 Python
python使用QQ邮箱实现自动发送邮件
Jun 22 Python
如何一键升级Python所有包
Nov 05 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
python遍历一个目录,输出所有的文件名的实例
Apr 23 #Python
python 获取文件下所有文件或目录os.walk()的实例
Apr 23 #Python
You might like
消息持续发送的完整例子
2006/10/09 PHP
PHP EOT定界符的使用详解
2008/09/30 PHP
用PHP实现的四则运算表达式计算实现代码
2011/08/02 PHP
php启用sphinx全文搜索的实现方法
2014/12/24 PHP
详解PHP实现执行定时任务
2015/12/21 PHP
php使用preg_match()函数验证ip地址的方法
2017/01/07 PHP
thinkPHP5.0框架模块设计详解
2017/03/18 PHP
浅谈php的TS和NTS的区别
2019/03/13 PHP
详解阿里云视频直播PHP-SDK接入教程
2020/07/09 PHP
滚动经典最新话题[prototype框架]下编写
2006/10/03 Javascript
jquery怎样实现ajax联动框(二)
2013/03/08 Javascript
javascript中如何处理引号编码"
2013/08/15 Javascript
禁止空格提交表单的js代码
2013/11/17 Javascript
JavaScript实现弹出子窗口并传值给父窗口
2014/12/18 Javascript
jQuery学习笔记之jQuery中的$
2015/01/19 Javascript
jQuery通过扩展实现抖动效果的方法
2015/03/11 Javascript
jQuery基于扩展实现的倒计时效果
2016/05/14 Javascript
JavaScript核心语法总结(推荐)
2016/06/02 Javascript
浅谈javascript基础之客户端事件驱动
2016/06/10 Javascript
jQuery Ajax 实现分页 kkpager插件实例代码
2017/08/10 jQuery
Vue代码分割懒加载的实现方法
2017/11/23 Javascript
Vue 组件参数校验与非props特性的方法
2019/02/12 Javascript
[15:58]DOTA2国际邀请赛采访专栏:Tongfu.Sansheng&KingJ,DK.rOtk
2013/08/08 DOTA
[02:21]2018完美盛典章节片——初心
2018/12/17 DOTA
一些Python中的二维数组的操作方法
2015/05/02 Python
Python实现的朴素贝叶斯算法经典示例【测试可用】
2018/06/13 Python
详解python3 + Scrapy爬虫学习之创建项目
2019/04/12 Python
对Tensorflow中Device实例的生成和管理详解
2020/02/04 Python
matplotlib绘制鼠标的十字光标的实现(自定义方式,官方实例)
2021/01/10 Python
HTML5表格_动力节点Java学院整理
2017/07/11 HTML / CSS
W Concept美国:精选全球独立设计师
2017/02/22 全球购物
Veronica Beard官网:在酷、经典和别致之间找到了平衡
2018/01/11 全球购物
DOUGLAS波兰:在线销售香水和化妆品
2020/07/05 全球购物
工程资料员岗位职责
2014/03/10 职场文书
MySQL触发器的使用
2021/05/24 MySQL
MySQL GRANT用户授权的实现
2021/06/18 MySQL