TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python高级应用实例对比:高效计算大文件中的最长行的长度
Jun 08 Python
详解Python中with语句的用法
Apr 15 Python
对Python中数组的几种使用方法总结
Jun 28 Python
python遍历文件夹找出文件夹后缀为py的文件方法
Oct 21 Python
Python socket 套接字实现通信详解
Aug 27 Python
Python实现图片添加文字
Nov 26 Python
Python 矩阵转置的几种方法小结
Dec 02 Python
python实现FTP文件传输的方法(服务器端和客户端)
Mar 20 Python
django queryset相加和筛选教程
May 18 Python
Python字符串格式化常用手段及注意事项
Jun 17 Python
详解numpy1.19.4与python3.9版本冲突解决
Dec 15 Python
python实现自动化群控的步骤
Apr 11 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
smtp邮件发送一例
2006/10/09 PHP
PHP 中关于ord($str)>0x80的详细说明
2012/09/23 PHP
FastCGI 进程意外退出造成500错误
2015/07/26 PHP
PHP正则表达式过滤html标签属性(DEMO)
2016/05/04 PHP
javascript基础的动画教程,直观易懂
2007/01/10 Javascript
Javascript 表单之间的数据传递代码
2008/12/04 Javascript
javascript 获取图片颜色
2009/04/05 Javascript
基于JQuery的密码强度验证代码
2010/03/01 Javascript
Web 前端设计模式--Dom重构 提高显示性能
2010/10/22 Javascript
jQuery动态添加 input type=file的实现代码
2012/06/14 Javascript
在jquery中combobox多选的不兼容问题总结
2013/12/24 Javascript
JavaScript调用ajax获取文本文件内容实现代码
2014/03/28 Javascript
JavaScript中获取样式的原生方法小结
2014/10/08 Javascript
js中的内部属性与delete操作符介绍
2015/08/10 Javascript
javascript多物体运动实现方法分析
2016/01/08 Javascript
深入解析Backbone.js框架的依赖库Underscore.js的作用
2016/05/07 Javascript
你不需要jQuery(三) 新AJAX方法fetch()
2016/06/14 Javascript
浅谈js内置对象Math的属性和方法(推荐)
2016/09/19 Javascript
Bootstrap 实现查询的完美方法
2016/10/26 Javascript
使用Ajax和Jquery配合数据库实现下拉框的二级联动的示例
2018/01/25 jQuery
Vue.js 动态为img的src赋值方法
2018/03/14 Javascript
vue-cli3.0配置及使用注意事项详解
2018/09/05 Javascript
详解easyui 切换主题皮肤
2019/04/04 Javascript
有趣的JavaScript隐式类型转换操作实例分析
2020/05/02 Javascript
python根据距离和时长计算配速示例
2014/02/16 Python
Python 的类、继承和多态详解
2017/07/16 Python
python添加模块搜索路径方法
2017/09/11 Python
Python3 单行多行万能正则匹配方法
2019/01/07 Python
Python面向对象进阶学习
2019/05/21 Python
Python Sympy计算梯度、散度和旋度的实例
2019/12/06 Python
Python求凸包及多边形面积教程
2020/04/12 Python
Python基于jieba, wordcloud库生成中文词云
2020/05/13 Python
追悼会子女答谢词
2014/01/28 职场文书
领导干部个人对照检查材料(群众路线)
2014/09/26 职场文书
JavaScript 防篡改对象的用法示例
2021/04/24 Javascript
MySQL获取所有分类的前N条记录
2021/05/07 MySQL