TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python使用scrapy采集数据过程中放回下载过大页面的方法
Apr 08 Python
在Python的Django框架下使用django-tagging的教程
May 30 Python
深入讲解Python中的迭代器和生成器
Oct 26 Python
浅谈numpy数组中冒号和负号的含义
Apr 18 Python
Python中安装easy_install的方法
Nov 18 Python
pyqt5实现俄罗斯方块游戏
Jan 11 Python
对Django 中request.get和request.post的区别详解
Aug 12 Python
Tensorflow读取并输出已保存模型的权重数值方式
Jan 04 Python
Jupyter notebook快速入门教程(推荐)
May 18 Python
Python描述数据结构学习之哈夫曼树篇
Sep 07 Python
matplotlib源码解析标题实现(窗口标题,标题,子图标题不同之间的差异)
Feb 22 Python
Python使用BeautifulSoup4修改网页内容
May 20 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
PHP 程序员的调试技术小结
2009/11/15 PHP
部署PHP项目应该注意的几点事项分享
2013/12/20 PHP
PHP获取windows登录用户名的方法
2014/06/24 PHP
对于ThinkPHP框架早期版本的一个SQL注入漏洞详细分析
2014/07/04 PHP
thinkphp微信开之安全模式消息加密解密不成功的解决办法
2015/12/02 PHP
laravel框架上传图片实现实时预览功能
2019/10/14 PHP
深入理解JavaScript中的传值与传引用
2013/12/09 Javascript
js setTimeout()函数介绍及应用以倒计时为例
2013/12/12 Javascript
KnockoutJs快速入门教程
2016/05/16 Javascript
详解node child_process模块学习笔记
2018/01/24 Javascript
深入了解javascript 数组的sort方法
2018/06/01 Javascript
JS实现统计字符串中字符出现个数及最大个数功能示例
2018/06/04 Javascript
Javascript 之封装(Package)
2018/09/14 Javascript
js实现带积分弹球小游戏
2020/07/21 Javascript
VUE项目axios请求头更改Content-Type操作
2020/07/24 Javascript
javascript实现一款好看的秒表计时器
2020/09/05 Javascript
Python中的defaultdict与__missing__()使用介绍
2018/02/03 Python
浅谈Pandas 排序之后索引的问题
2018/06/07 Python
python2.7实现邮件发送功能
2018/12/12 Python
Python minidom模块用法示例【DOM写入和解析XML】
2019/03/25 Python
通过pycharm使用git的步骤(图文详解)
2019/06/13 Python
python实现串口通信的示例代码
2020/02/10 Python
德国机场停车位比较和预订网站:Ich-parke-billiger
2018/01/08 全球购物
Fox Racing英国官网:越野摩托车和山地自行车服装
2020/02/26 全球购物
北京-环亚运商测试题.net程序员初步测试题
2013/05/28 面试题
介绍一下Ruby的多线程处理
2013/02/01 面试题
商务专员岗位职责
2013/11/23 职场文书
奥巴马当选演讲稿
2014/09/10 职场文书
自愿离婚协议书范本
2014/09/13 职场文书
钳工实训报告总结
2014/11/04 职场文书
2015年信息中心工作总结
2015/05/25 职场文书
2015年学校禁毒工作总结
2015/05/27 职场文书
三十年同学聚会感言
2015/07/30 职场文书
护理心得体会范文
2016/01/22 职场文书
一文了解MySQL二级索引的查询过程
2022/02/24 MySQL
《废话连篇——致新手》——chinapizza
2022/04/05 无线电