TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
在Django中管理Users和Permissions以及Groups的方法
Jul 23 Python
Python的消息队列包SnakeMQ使用初探
Jun 29 Python
Python之py2exe打包工具详解
Jun 14 Python
利用Python如何制作好玩的GIF动图详解
Jul 11 Python
Python 判断文件或目录是否存在的实例代码
Jul 19 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
Python面向对象程序设计多继承和多态用法示例
Apr 08 Python
pyqt5 实现在别的窗口弹出进度条
Jun 18 Python
简单了解Python3里的一些新特性
Jul 13 Python
flask开启多线程的具体方法
Aug 02 Python
python通用数据库操作工具 pydbclib的使用简介
Dec 21 Python
python利用xpath爬取网上数据并存储到django模型中
Feb 26 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
239军机修复记
2021/03/02 无线电
从一个不错的留言本弄的mysql数据库操作类
2007/09/02 PHP
PHP性能优化工具篇Benchmark类调试执行时间
2011/12/06 PHP
PHP定时执行计划任务的多种方法小结
2011/12/19 PHP
使用php判断网页是否gzip压缩
2013/06/25 PHP
解决ThinkPHP下使用上传插件Uploadify浏览器firefox报302错误的方法
2015/12/18 PHP
thinkPHP查询方式小结
2016/01/09 PHP
动态表单验证的操作方法和TP框架里面的ajax表单验证
2017/07/19 PHP
laravel 解决多库下的DB::transaction()事务失效问题
2019/10/21 PHP
Laravel实现ORM带条件搜索分页
2019/10/24 PHP
PHP 范围解析操作符(::)用法分析【访问静态成员和类常量】
2020/04/14 PHP
Firefox div高度自适应
2009/04/28 Javascript
用Javascript实现Windows任务管理器的代码
2012/03/27 Javascript
javascript游戏开发之《三国志曹操传》零部件开发(三)情景对话中仿打字机输出文字
2013/01/23 Javascript
JSONP之我见
2015/03/24 Javascript
JQuery中属性过滤选择器用法实例分析
2015/05/18 Javascript
跟我学习javascript的prototype使用注意事项
2015/11/17 Javascript
怎么引入(调用)一个JS文件
2016/05/26 Javascript
浅谈jQuery 中的事件冒泡和阻止默认行为
2016/05/28 Javascript
node.js 动态执行脚本
2016/06/02 Javascript
详解前端构建工具gulpjs的使用介绍及技巧
2017/01/19 Javascript
浅谈Webpack多页应用HMR卡住问题
2019/04/24 Javascript
python实现颜色rgb和hex相互转换的函数
2015/03/19 Python
Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)
2016/09/21 Python
python pandas dataframe 按列或者按行合并的方法
2018/04/12 Python
总结python中pass的作用
2019/02/27 Python
python中id函数运行方式
2020/07/03 Python
matplotlib 三维图表绘制方法简介
2020/09/20 Python
Python爬虫模拟登陆哔哩哔哩(bilibili)并突破点选验证码功能
2020/12/21 Python
CSS3+DIV实现漂亮的动画彩色标签
2016/06/16 HTML / CSS
加拿大花店:1800Flowers.ca
2016/11/16 全球购物
工业设计专业推荐信
2013/10/29 职场文书
《三顾茅庐》教学反思
2014/04/10 职场文书
计划生育证明书写要求
2014/09/17 职场文书
2014年信息技术工作总结
2014/12/16 职场文书
SONY AN-LP1 短波有源天线放大器
2021/04/22 无线电