TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Django的session中对于用户验证的支持
Jul 23 Python
整理Python 常用string函数(收藏)
May 30 Python
详解Python读取配置文件模块ConfigParser
May 11 Python
Python使用内置json模块解析json格式数据的方法
Jul 20 Python
python学习教程之Numpy和Pandas的使用
Sep 11 Python
Python Nose框架编写测试用例方法
Oct 26 Python
使用django-crontab实现定时任务的示例
Feb 26 Python
使用 Python 实现文件递归遍历的三种方式
Jul 18 Python
python 在指定范围内随机生成不重复的n个数实例
Jan 28 Python
Python爬虫requests库多种用法实例
May 28 Python
python上selenium的弹框操作实现
Jul 13 Python
Pytorch 如何实现LSTM时间序列预测
May 17 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
怎么使 Mysql 数据同步
2006/10/09 PHP
PHP计划任务、定时执行任务的实现代码
2011/04/23 PHP
php筛选不存在的图片资源
2015/04/28 PHP
php遍历解析xml字符串的方法
2016/05/05 PHP
PHP两个n位的二进制整数相加问题的解决
2018/08/26 PHP
php5对象复制、clone、浅复制与深复制实例详解
2019/08/14 PHP
PHP 图片合成、仿微信群头像的方法示例
2019/10/25 PHP
用cssText批量修改样式
2009/08/29 Javascript
IE6-IE9中tbody的innerHTML不能赋值的解决方法
2014/06/05 Javascript
图片放大镜jquery.jqzoom.js使用实例附放大镜图标
2014/06/19 Javascript
jQuery实现信息提示框(带有圆角框与动画)效果
2015/08/07 Javascript
JS实现鼠标滑过折叠与展开菜单效果代码
2015/09/06 Javascript
微信公众号支付H5调用支付解析
2016/11/04 Javascript
javascript图片预览和上传(兼容IE)
2017/03/15 Javascript
JQuery EasyUI的一些常用组件
2017/07/12 jQuery
Vue通过URL传参如何控制全局console.log的开关详解
2017/12/07 Javascript
Webpack框架核心概念(知识点整理)
2017/12/22 Javascript
VueJs组件之父子通讯的方式
2018/05/06 Javascript
Bootstrap 按钮样式与使用代码详解
2018/12/09 Javascript
Element实现表格嵌套、多个表格共用一个表头的方法
2020/05/09 Javascript
Vue使用路由钩子拦截器beforeEach和afterEach监听路由
2020/11/16 Javascript
在Python的Flask框架中使用日期和时间的教程
2015/04/21 Python
Python判断Abundant Number的方法
2015/06/15 Python
python实现多线程抓取知乎用户
2016/12/12 Python
django限制匿名用户访问及重定向的方法实例
2018/02/07 Python
python pands实现execl转csv 并修改csv指定列的方法
2018/12/12 Python
python 解决cv2绘制中文乱码问题
2019/12/23 Python
python代码xml转txt实例
2020/03/10 Python
python numpy矩阵信息说明,shape,size,dtype
2020/05/22 Python
css3实现的下拉菜单效果示例
2014/01/22 HTML / CSS
CSS3媒体查询Media Queries基础学习教程
2016/02/29 HTML / CSS
HTML5本地存储之Web Storage应用介绍
2013/01/06 HTML / CSS
AmazeUI底部导航栏与分享按钮的示例代码
2020/08/18 HTML / CSS
妈妈活动方案
2014/08/15 职场文书
辩护词范文大全
2015/05/21 职场文书
离婚被告代理词
2015/05/23 职场文书