TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
python 调用win32pai 操作cmd的方法
May 28 Python
Python基本数据结构与用法详解【列表、元组、集合、字典】
Mar 23 Python
python读写csv文件方法详细总结
Jul 05 Python
python爬虫解决验证码的思路及示例
Aug 01 Python
python获取array中指定元素的示例
Nov 26 Python
python二维键值数组生成转json的例子
Dec 06 Python
TensorFlow实现打印每一层的输出
Jan 21 Python
详解Python中pyautogui库的最全使用方法
Apr 01 Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 Python
详解Python中的路径问题
Sep 02 Python
python re的findall和finditer的区别详解
Nov 15 Python
浅析Python的命名空间与作用域
Nov 25 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
PHP4实际应用经验篇(1)
2006/10/09 PHP
phpcms模块开发之swfupload的使用介绍
2013/04/28 PHP
基于PHP一些十分严重的缺陷详解
2013/06/03 PHP
PHP网站开发中常用的8个小技巧
2015/02/13 PHP
Zero Clipboard js+swf实现的复制功能使用方法
2010/03/07 Javascript
javascript中的对象创建 实例附注释
2011/02/08 Javascript
JS中prototype关键字的功能介绍及使用示例
2013/07/21 Javascript
深入理解JavaScript系列(26):设计模式之构造函数模式详解
2015/03/03 Javascript
PHP结合jQuery实现红蓝投票功能特效
2015/07/22 Javascript
JavaScript中innerHTML,innerText,outerHTML的用法及区别
2015/09/01 Javascript
jQuery实现自动切换播放的经典滑动门效果
2015/09/12 Javascript
深入学习jQuery中的data()
2016/12/22 Javascript
如何用JS/HTML将时间戳转换为“xx天前”的形式
2017/02/06 Javascript
jQuery插件zTree实现单独选中根节点中第一个节点示例
2017/03/08 Javascript
docker中编译nodejs并使用nginx启动
2017/06/23 NodeJs
JS库之Particles.js中文开发手册及参数详解
2017/09/13 Javascript
jQuery实现下拉菜单动态添加数据点击滑出收起其他功能
2018/06/14 jQuery
详解vuex之store源码简单解析
2019/06/13 Javascript
JavaScript如何获取一个元素的样式信息
2019/07/29 Javascript
新手入门js闭包学习过程解析
2019/10/08 Javascript
js的Object.assign用法示例分析
2020/03/05 Javascript
JS实现手写 forEach算法示例
2020/04/29 Javascript
[38:44]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第二局
2016/02/25 DOTA
python实现的udp协议Server和Client代码实例
2014/06/04 Python
Python实现各种排序算法的代码示例总结
2015/12/11 Python
Python利用ElementTree模块处理XML的方法详解
2017/08/31 Python
python利用pandas将excel文件转换为txt文件的方法
2018/10/23 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
2019/05/10 Python
详解python深浅拷贝区别
2019/06/24 Python
Python 图像处理: 生成二维高斯分布蒙版的实例
2019/07/04 Python
阿迪达斯中国官网:Adidas中国
2020/12/14 全球购物
乐观自信演讲稿范文
2014/05/21 职场文书
文明城市创建标语
2014/06/16 职场文书
小学教学工作总结2015
2015/05/13 职场文书
描写九月优美句子(39条)
2019/09/11 职场文书
Java Spring 控制反转(IOC)容器详解
2021/10/05 Java/Android