Tensorflow之Saver的用法详解


Posted in Python onApril 23, 2018

Saver的用法

1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。

为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. Saver的实例

下面以一个例子来讲述如何使用Saver类 

import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))
  1. isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
  2. train_steps:表示训练的次数,例子中使用100
  3. checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
  4. checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

2.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

Tensorflow之Saver的用法详解

打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

Tensorflow之Saver的用法详解

2.1测试阶段

测试阶段使用saver.restore()方法恢复变量:

sess:表示当前会话,之前保存的结果将被加载入这个会话

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

Tensorflow之Saver的用法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则匹配抓取豆瓣电影链接和评论代码分享
Dec 27 Python
python使用rabbitmq实现网络爬虫示例
Feb 20 Python
python登录pop3邮件服务器接收邮件的方法
Apr 30 Python
Python中property属性实例解析
Feb 10 Python
python实现list由于numpy array的转换
Apr 04 Python
想学python 这5本书籍你必看!
Dec 11 Python
对Python中画图时候的线类型详解
Jul 07 Python
Python学习笔记之Zip和Enumerate用法实例分析
Aug 14 Python
用python爬取历史天气数据的方法示例
Dec 30 Python
使用Python 自动生成 Word 文档的教程
Feb 13 Python
python Socket网络编程实现C/S模式和P2P
Jun 22 Python
Python join()函数原理及使用方法
Nov 14 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
python遍历一个目录,输出所有的文件名的实例
Apr 23 #Python
python 获取文件下所有文件或目录os.walk()的实例
Apr 23 #Python
You might like
分享常见的几种页面静态化的方法
2015/01/08 PHP
为百度UE编辑器上传图片添加水印功能
2015/04/16 PHP
php while循环控制的简单实例
2016/05/30 PHP
ThinkPHP Where 条件中常用表达式示例(详解)
2017/03/31 PHP
YII2框架中ActiveDataProvider与GridView的配合使用操作示例
2020/03/18 PHP
PHP二维数组分页2种实现方法解析
2020/07/09 PHP
Jquery中getJSON在asp.net中的使用说明
2011/03/10 Javascript
jQuery.event兼容各浏览器的event详细解析
2013/12/18 Javascript
jquery下div 的resize事件示例代码
2014/03/09 Javascript
基于Jquery制作图片文字排版预览效果附源码下载
2015/11/18 Javascript
noty ? jQuery通知插件全面解析
2016/05/18 Javascript
jQuery增加、删除及修改select option的方法
2016/08/19 Javascript
微信小程序 倒计时组件实现代码
2016/10/24 Javascript
原生js实现网页顶部自动下拉/收缩广告效果
2017/01/20 Javascript
Angular JS 生成动态二维码的方法
2017/02/23 Javascript
JavaScript实现两个select下拉框选项左移右移
2017/03/09 Javascript
JS实现指定区域的全屏显示功能示例
2019/04/25 Javascript
js数组中去除重复值的几种方法
2020/08/03 Javascript
原生js实现下拉框选择组件
2021/01/20 Javascript
NestJs使用Mongoose对MongoDB操作的方法
2021/02/22 Javascript
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
2014/08/22 Python
Python中使用不同编码读写txt文件详解
2015/05/28 Python
python检查字符串是否是正确ISBN的方法
2015/07/11 Python
浅谈python中set使用
2016/06/30 Python
Python内置模块logging用法实例分析
2018/02/12 Python
python实现SOM算法
2018/02/23 Python
python机器学习之贝叶斯分类
2018/03/26 Python
python 实现将list转成字符串,中间用空格隔开
2019/12/25 Python
DjangoWeb使用Datatable进行后端分页的实现
2020/05/18 Python
用Python匹配HTML tag的时候,<.*>和<.*?>有什么区别
2012/11/04 面试题
平面设计的岗位职责
2013/11/08 职场文书
党员批评与自我批评
2014/02/12 职场文书
团队执行力培训心得体会
2015/08/15 职场文书
教您怎么制定西餐厅运营方案 ?
2019/07/05 职场文书
七年级话题作文之执着
2019/11/19 职场文书
Mysql关于数据库是否应该使用外键约束详解说明
2021/10/24 MySQL