TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
查看Python安装路径以及安装包路径小技巧
Apr 28 Python
python实现按行切分文本文件的方法
Apr 18 Python
python实现图片处理和特征提取详解
Nov 13 Python
python导出hive数据表的schema实例代码
Jan 22 Python
Windows下的Python 3.6.1的下载与安装图文详解(适合32位和64位)
Feb 21 Python
python3.5绘制随机漫步图
Aug 27 Python
django商品分类及商品数据建模实例详解
Jan 03 Python
pycharm运行程序时看不到任何结果显示的解决
Feb 21 Python
python中adb有什么功能
Jun 07 Python
Pandas的Apply函数具体使用
Jul 21 Python
python 绘制正态曲线的示例
Sep 24 Python
python 利用 PIL 将数组值转成图片的实现
Apr 12 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
Windows中安装Apache2和PHP4权威指南
2006/11/18 PHP
php字符串截取中文截取2,单字节截取模式
2007/12/10 PHP
PHP下对数组进行排序的函数
2010/08/08 PHP
WordPress开发中自定义菜单的相关PHP函数使用简介
2016/01/05 PHP
PHP实现腾讯与百度坐标转换
2017/08/05 PHP
js checkbox(复选框) 使用集锦
2009/04/28 Javascript
Javascript的匿名函数小结
2009/12/31 Javascript
改善用户体验的五款jQuery插件分享
2011/05/22 Javascript
script的async属性以非阻塞的模式加载脚本
2013/01/15 Javascript
JavaScript实现N皇后问题算法谜题解答
2014/12/29 Javascript
在JavaScript的jQuery库中操作AJAX的方法讲解
2015/08/15 Javascript
使用jQuery判断Div是否在可视区域的方法 判断div是否可见
2016/02/17 Javascript
JS去除空格和换行的正则表达式(推荐)
2016/06/14 Javascript
JavaScript获取键盘按键的键码(参照表)
2017/01/10 Javascript
使用Webpack提高Vue.js应用的方式汇总(四种)
2017/07/10 Javascript
vue路由拦截及页面跳转的设置方法
2018/05/24 Javascript
微信小程序版本自动更新的方法
2019/06/14 Javascript
vue中实现上传文件给后台实例详解
2019/08/22 Javascript
layer插件实现在弹出层中弹出一警告提示并关闭弹出层的方法
2019/09/24 Javascript
Python简单获取自身外网IP的方法
2016/09/18 Python
Python 异常处理的实例详解
2017/09/11 Python
利用Python正则表达式过滤敏感词的方法
2019/01/21 Python
python爬虫selenium和phantomJs使用方法解析
2019/08/08 Python
关于django 1.10 CSRF验证失败的解决方法
2019/08/31 Python
Django 批量插入数据的实现方法
2020/01/12 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
2020/08/03 Python
python3实现简单飞机大战
2020/11/29 Python
莫斯科的韩国化妆品店:Sifo
2019/12/04 全球购物
百度软件工程师职位
2013/02/14 面试题
十八大报告观后感
2014/01/28 职场文书
李培根演讲稿
2014/05/22 职场文书
乔丹名人堂演讲稿
2014/05/24 职场文书
医院反腐倡廉演讲稿
2014/09/16 职场文书
2014年机关作风建设工作总结
2014/10/23 职场文书
班主任班级管理心得体会
2016/01/07 职场文书
公开致歉信
2019/06/24 职场文书