TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python内置的字符串处理函数详细整理(覆盖日常所用)
Aug 19 Python
Python中asyncore的用法实例
Sep 29 Python
Python内置的HTTP协议服务器SimpleHTTPServer使用指南
Mar 30 Python
python文件的md5加密方法
Apr 06 Python
详解Python 2.6 升级至 Python 2.7 的实践心得
Apr 27 Python
Python定时器实例代码
Nov 01 Python
Numpy中矩阵matrix读取一列的方法及数组和矩阵的相互转换实例
Jul 02 Python
Python3爬虫爬取百姓网列表并保存为json功能示例【基于request、lxml和json模块】
Dec 05 Python
使用Python+wxpy 找出微信里把你删除的好友实例
Feb 21 Python
python 数据分析实现长宽格式的转换
May 18 Python
Python的collections模块真的很好用
Mar 01 Python
LyScript实现绕过反调试保护的示例详解
Aug 14 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
php 缓存函数代码
2008/08/27 PHP
关于ob_get_contents(),ob_end_clean(),ob_start(),的具体用法详解
2013/06/24 PHP
详谈PHP面向对象中常用的关键字和魔术方法
2017/02/04 PHP
php实现等比例压缩图片
2018/07/26 PHP
js程序中美元符号$是什么
2008/06/05 Javascript
js removeChild 障眼法 可能出现的错误
2009/10/06 Javascript
Javascript学习笔记1 数据类型
2010/01/11 Javascript
利用jQuery 实现GridView异步排序、分页的代码
2010/02/06 Javascript
Extjs NumberField后面加单位实现思路
2013/07/30 Javascript
jQuery鼠标事件汇总
2015/08/30 Javascript
CKEditor无法验证的解决方案(js验证+jQuery Validate验证)
2016/05/09 Javascript
JavaScript实现经典排序算法之插入排序
2016/12/28 Javascript
koa2 从入门到精通(小结)
2019/07/23 Javascript
TypeScript高级用法的知识点汇总
2019/12/17 Javascript
koa2的中间件功能及应用示例
2020/03/05 Javascript
详解element-ui 表单校验 Rules 配置 常用黑科技
2020/07/11 Javascript
基于VSCode调试网页JavaScript代码过程详解
2020/07/20 Javascript
[02:27]刀塔重生降临
2015/10/14 DOTA
分享Python字符串关键点
2015/12/13 Python
用不到50行的Python代码构建最小的区块链
2017/11/16 Python
Python实现批量压缩图片
2018/01/25 Python
Python实现全排列的打印
2018/08/18 Python
使用Python如何测试InnoDB与MyISAM的读写性能
2018/09/18 Python
TensorFlow查看输入节点和输出节点名称方式
2020/01/04 Python
解决Django提交表单报错:CSRF token missing or incorrect的问题
2020/03/13 Python
python实现扑克牌交互式界面发牌程序
2020/04/22 Python
python中strip(),lstrip(),rstrip()函数的使用讲解
2020/11/17 Python
介绍一下Java的安全机制
2012/06/28 面试题
前处理班长职位说明书
2014/03/01 职场文书
遗嘱继承公证书
2014/04/09 职场文书
学习型党组织心得体会
2014/09/12 职场文书
2014年变电站工作总结
2014/12/19 职场文书
用人单位的规章制度,怎样制定才是有效的?
2019/07/09 职场文书
Golang标准库syscall详解(什么是系统调用)
2021/05/25 Golang
Python实现简繁体转换
2021/06/07 Python
一文彻底理解js原生语法prototype,__proto__和constructor
2021/10/24 Javascript