Tensorflow之Saver的用法详解


Posted in Python onApril 23, 2018

Saver的用法

1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。

为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. Saver的实例

下面以一个例子来讲述如何使用Saver类 

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))
  1. isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
  2. train_steps:表示训练的次数,例子中使用100
  3. checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
  4. checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

2.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

Tensorflow之Saver的用法详解

打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

Tensorflow之Saver的用法详解

2.1测试阶段

测试阶段使用saver.restore()方法恢复变量:

sess:表示当前会话,之前保存的结果将被加载入这个会话

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

Tensorflow之Saver的用法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中字符编码简介、方法及使用建议
Jan 08 Python
python中getaddrinfo()基本用法实例分析
Jun 28 Python
Python注释详解
Jun 01 Python
Python中的字符串替换操作示例
Jun 27 Python
使用Python3 编写简单信用卡管理程序
Dec 21 Python
Python实现通过继承覆盖方法示例
Jul 02 Python
ERLANG和PYTHON互通实现过程详解
Jul 05 Python
Python 3.6 中使用pdfminer解析pdf文件的实现
Sep 25 Python
Pytorch 实现sobel算子的卷积操作详解
Jan 10 Python
python 爬虫 实现增量去重和定时爬取实例
Feb 28 Python
Pycharm配置PyQt5环境的教程
Apr 02 Python
Django分页器的用法你都了解吗
May 26 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
python遍历一个目录,输出所有的文件名的实例
Apr 23 #Python
python 获取文件下所有文件或目录os.walk()的实例
Apr 23 #Python
You might like
php代码优化及php相关问题总结
2006/10/09 PHP
pdo中使用参数化查询sql
2011/08/11 PHP
php使用百度翻译api示例分享
2014/01/31 PHP
Laravel框架表单验证操作实例分析
2019/09/30 PHP
详解将数据从Laravel传送到vue的四种方式
2019/10/16 PHP
浅谈laravel框架sql中groupBy之后排序的问题
2019/10/17 PHP
laravel利用中间件做防非法登录和权限控制示例
2019/10/21 PHP
javascript怎么禁用浏览器后退按钮
2014/03/27 Javascript
JavaScript实现广告的关闭与显示效果实例
2015/07/02 Javascript
JS实现Fisheye效果动感放大菜单代码
2015/10/21 Javascript
javascript将中国数字格式转换成欧式数字格式的简单实例
2016/08/02 Javascript
JavaScript实现拖拽元素对齐到网格(每次移动固定距离)
2016/11/30 Javascript
JQuery 动态生成Table表格实例代码
2016/12/02 Javascript
详解使用Vue.Js结合Jquery Ajax加载数据的两种方式
2017/01/10 Javascript
js实现股票实时刷新数据案例
2017/05/14 Javascript
页面间固定参数,通过cookie传值的实现方法
2017/05/31 Javascript
Vue组件化通讯的实例代码
2017/06/23 Javascript
详解vue组件通信的三种方式
2017/06/30 Javascript
python实现随机密码字典生成器示例
2014/04/09 Python
python各种语言间时间的转化实现代码
2016/03/23 Python
python写一个md5解密器示例
2018/02/23 Python
python3+selenium实现qq邮箱登陆并发送邮件功能
2019/01/23 Python
python实现ip代理池功能示例
2019/07/05 Python
Python SQLAlchemy入门教程(基本用法)
2019/11/11 Python
Python 剪绳子的多种思路实现(动态规划和贪心)
2020/02/24 Python
python GUI库图形界面开发之PyQt5单选按钮控件QRadioButton详细使用方法与实例
2020/02/28 Python
从0到1使用python开发一个半自动答题小程序的实现
2020/05/12 Python
CSS3制作文字半透明倒影效果的两种实现方式
2014/08/08 HTML / CSS
瑞典轮胎在线:Tirendo.se
2018/06/21 全球购物
英国钻石公司:British Diamond Company
2020/02/16 全球购物
什么是java序列化,如何实现java序列化
2012/11/14 面试题
年终奖发放方案
2014/06/02 职场文书
安全生产先进个人事迹材料
2014/12/30 职场文书
高校自主招生自荐信2015
2015/03/04 职场文书
css height属性中的calc方法详解
2021/06/03 HTML / CSS
MySQL池化框架学习接池自定义
2022/07/23 MySQL