TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
python中 ? : 三元表达式的使用介绍
Oct 09 Python
Python中isnumeric()方法的使用简介
May 19 Python
Django中模版的子目录与include标签的使用方法
Jul 16 Python
Python实现拷贝多个文件到同一目录的方法
Sep 19 Python
利用python生成一个导出数据库的bat脚本文件的方法
Dec 30 Python
详解如何在Apache中运行Python WSGI应用
Jan 02 Python
python获取array中指定元素的示例
Nov 26 Python
python os.path.isfile 的使用误区详解
Nov 29 Python
Python argparse模块使用方法解析
Feb 20 Python
pygame实现飞机大战
Mar 11 Python
Python多线程操作之互斥锁、递归锁、信号量、事件实例详解
Mar 24 Python
Python实现Word文档转换Markdown的示例
Dec 22 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
广播爱好者需要了解的天线知识
2021/03/01 无线电
比较strtr, str_replace和preg_replace三个函数的效率
2013/06/26 PHP
Yii控制器中操作视图js的方法
2016/07/04 PHP
PHP实现的超长文本分页显示功能示例
2018/06/04 PHP
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
js,jQuery 排序的实现代码,网页标签排序的实现,标签排序
2011/04/27 Javascript
关于在IE下的一个安全BUG --可用于跟踪用户的系统鼠标位置
2013/04/17 Javascript
JavaScript判断变量是否为undefined的两种写法区别
2013/12/04 Javascript
Javascript加载速度慢的解决方案
2014/03/11 Javascript
基于jquery实现轮播特效
2016/04/22 Javascript
javascript实现获取指定精度的上传文件的大小简单实例
2016/10/25 Javascript
jQuery实现倒计时重新发送短信验证码功能示例
2017/01/12 Javascript
详解Angular-cli生成组件修改css成less或sass的实例
2017/07/27 Javascript
node通过express搭建自己的服务器
2017/09/30 Javascript
微信小程序项目实践之九宫格实现及item跳转功能
2018/07/19 Javascript
详解Angular6 热加载配置方案
2018/08/18 Javascript
vue-router之实现导航切换过渡动画效果
2019/10/31 Javascript
js实现列表向上无限滚动
2020/01/13 Javascript
vue实现简单跑马灯效果
2020/05/25 Javascript
vue实现员工信息录入功能
2020/06/11 Javascript
vant 中van-list的用法说明
2020/11/11 Javascript
windows下ipython的安装与使用详解
2016/10/20 Python
rabbitmq(中间消息代理)在python中的使用详解
2017/12/14 Python
python脚本开机自启的实现方法
2019/06/28 Python
Python 线程池用法简单示例
2019/10/02 Python
将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例
2020/01/04 Python
20行代码教你用python给证件照换底色的方法示例
2021/02/05 Python
HTML5实现页面切换激活的PageVisibility API使用初探
2016/05/13 HTML / CSS
美国排名第一的泳池用品直接来源:In The Swim
2019/09/23 全球购物
网游商务专员求职信
2013/10/15 职场文书
毕业生的求职信范文分享
2013/12/04 职场文书
幼儿园教师的自我评价范文
2014/09/17 职场文书
鸦片战争观后感
2015/06/09 职场文书
雷锋观后感
2015/06/10 职场文书
曾国藩励志经典名言37句,蕴含哲理
2019/10/14 职场文书
Win11 vmware不兼容怎么办?Win11与VMware虚拟机不兼容的解决方法
2023/01/09 数码科技