TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python中optionParser模块的使用方法实例教程
Aug 29 Python
Python xlrd读取excel日期类型的2种方法
Apr 28 Python
Python二分查找详解
Sep 13 Python
运行django项目指定IP和端口的方法
May 14 Python
用Python分析3天破10亿的《我不是药神》到底神在哪?
Jul 12 Python
Python面向对象程序设计中类的定义、实例化、封装及私有变量/方法详解
Feb 28 Python
python实现感知机线性分类模型示例代码
Jun 02 Python
python基于paramiko将文件上传到服务器代码实现
Jul 08 Python
django项目简单调取百度翻译接口的方法
Aug 06 Python
python3 动态模块导入与全局变量使用实例
Dec 22 Python
Python中无限循环需要什么条件
May 27 Python
python 下划线的不同用法
Oct 24 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
析构函数与php的垃圾回收机制详解
2013/10/28 PHP
php查询mysql数据库并将结果保存到数组的方法
2015/03/18 PHP
PHP中创建和验证哈希的简单方法实探
2015/07/06 PHP
extjs fckeditor集成代码
2009/05/10 Javascript
Ext grid 添加右击菜单
2009/11/26 Javascript
用JQuery在网页中实现分隔条功能的代码
2012/08/09 Javascript
设为首页加入收藏兼容360/火狐/谷歌/IE等主流浏览器的代码
2013/03/26 Javascript
深入解析JavaScript中的变量作用域
2013/12/06 Javascript
JS delegate与live浅析
2013/12/21 Javascript
javascript结合canvas实现图片旋转效果
2015/05/03 Javascript
jQuery仿天猫实现超炫的加入购物车
2015/05/04 Javascript
AngularJS基础学习笔记之表达式
2015/05/10 Javascript
JS+CSS实现大气清新的滑动菜单效果代码
2015/10/22 Javascript
jQuery常用知识点总结以及平时封装常用函数
2016/02/23 Javascript
jQuery中on绑定事件后引发的事件冒泡问题如何解决
2016/05/25 Javascript
JQuery+Bootstrap 自定义全屏Loading插件的示例demo
2019/07/03 jQuery
如何优雅地在Node应用中进行错误异常处理
2019/11/25 Javascript
Python实用日期时间处理方法汇总
2015/05/09 Python
解决yum对python依赖版本问题
2019/07/05 Python
Django中如何使用sass的方法步骤
2019/07/09 Python
Python的bit_length函数来二进制的位数方法
2019/08/27 Python
python实现简单成绩录入系统
2019/09/19 Python
python的pyecharts绘制各种图表详细(附代码)
2019/11/11 Python
简单介绍django提供的加密算法
2019/12/18 Python
解决pycharm同一目录下无法import其他文件
2020/02/12 Python
matplotlib quiver箭图绘制案例
2020/04/17 Python
HTML5 Canvas中使用路径描画二阶、三阶贝塞尔曲线
2015/01/01 HTML / CSS
美国排名第一的葡萄酒俱乐部:Firstleaf Wine Club
2020/01/02 全球购物
2014年服务员个人工作总结
2014/12/23 职场文书
2015圣诞节贺卡寄语
2015/03/24 职场文书
地道战观后感2000字
2015/06/04 职场文书
航班延误投诉信
2015/07/02 职场文书
2016年党风廉政建设承诺书
2016/03/25 职场文书
JavaScript严格模式不支持八进制的问题讲解
2021/11/07 Javascript
 Python 中 logging 模块使用详情
2022/03/03 Python
Pyhton爬虫知识之正则表达式详解
2022/04/01 Python