Keras SGD 随机梯度下降优化器参数设置方式


Posted in Python onJune 19, 2020

SGD 随机梯度下降

Keras 中包含了各式优化器供我们使用,但通常我会倾向于使用 SGD 验证模型能否快速收敛,然后调整不同的学习速率看看模型最后的性能,然后再尝试使用其他优化器。

Keras 中文文档中对 SGD 的描述如下:

keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量

参数:

lr:大或等于0的浮点数,学习率

momentum:大或等于0的浮点数,动量参数

decay:大或等于0的浮点数,每次更新后的学习率衰减值

nesterov:布尔值,确定是否使用Nesterov动量

参数设置

Time-Based Learning Rate Schedule

Keras 已经内置了一个基于时间的学习速率调整表,并通过上述参数中的 decay 来实现,学习速率的调整公式如下:

LearningRate = LearningRate * 1/(1 + decay * epoch)

当我们初始化参数为:

LearningRate = 0.1
decay = 0.001

大致变化曲线如下(非实际曲线,仅示意):

Keras SGD 随机梯度下降优化器参数设置方式

当然,方便起见,我们可以将优化器设置如下,使其学习速率随着训练轮次变化:

sgd = SGD(lr=learning_rate, decay=learning_rate/nb_epoch, momentum=0.9, nesterov=True)

Drop-Based Learning Rate Schedule

另外一种学习速率的调整方法思路是保持一个恒定学习速率一段时间后立即降低,是一种突变的方式。通常整个变化趋势为指数形式。

Keras SGD 随机梯度下降优化器参数设置方式

对应的学习速率变化公式如下:

LearningRate = InitialLearningRate * DropRate^floor(Epoch / EpochDrop)

实现需要使用 Keras 中的 LearningRateScheduler 模块:

from keras.callbacks import LearningRateScheduler
# learning rate schedule
def step_decay(epoch):
 initial_lrate = 0.1
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
 return lrate

lrate = LearningRateScheduler(step_decay)

# Compile model
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.compile(loss=..., optimizer=sgd, metrics=['accuracy'])

# Fit the model
model.fit(X, Y, ..., callbacks=[lrate])

补充知识:keras中的BGD和SGD

关于BGD和SGD

首先BGD为批梯度下降,即所有样本计算完毕后才进行梯度更新;而SGD为随机梯度下降,随机计算一次样本就进行梯度下降,所以速度快很多但容易陷入局部最优值。

折中的办法是采用小批的梯度下降,即把数据分成若干个批次,一批来进行一次梯度下降,减少随机性,计算量也不是很大。 mini-batch

keras中的batch_size就是小批梯度下降。

以上这篇Keras SGD 随机梯度下降优化器参数设置方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python从零实现贝叶斯分类器的机器学习的教程
Mar 31 Python
Python 中 list 的各项操作技巧
Apr 13 Python
python装饰器实例大详解
Oct 25 Python
Python实现的计算马氏距离算法示例
Apr 03 Python
使用python根据端口号关闭进程的方法
Nov 06 Python
PyQt5+requests实现车票查询工具
Jan 21 Python
Python如何获得百度统计API的数据并发送邮件示例代码
Jan 27 Python
python修改linux中文件(文件夹)的权限属性操作
Mar 05 Python
Python中logging日志记录到文件及自动分割的操作代码
Aug 05 Python
如何用tempfile库创建python进程中的临时文件
Jan 28 Python
pytorch交叉熵损失函数的weight参数的使用
May 24 Python
python区块链实现简版工作量证明
May 25 Python
python支持多继承吗
Jun 19 #Python
python和php哪个容易学
Jun 19 #Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 #Python
Python魔术方法专题
Jun 19 #Python
关于Theano和Tensorflow多GPU使用问题
Jun 19 #Python
如何对python的字典进行排序
Jun 19 #Python
浅谈Python中的继承
Jun 19 #Python
You might like
PHP4中session登录页面的应用
2008/07/25 PHP
国产PHP开发框架myqee新手快速入门教程
2014/07/14 PHP
原生JS实现Ajax通过POST方式与PHP进行交互的方法示例
2018/05/12 PHP
Ajax+Json 级联菜单实现代码
2009/10/27 Javascript
javascript 哈希表(hashtable)的简单实现
2010/01/20 Javascript
JS使用ajax方法获取指定url的head信息中指定字段值的方法
2015/03/24 Javascript
JavaScript实现强制重定向至HTTPS页面
2015/06/10 Javascript
JavaScript获取表格(table)当前行的值、删除行、增加行
2015/07/03 Javascript
JavaScript基本数据类型及值类型和引用类型
2015/08/25 Javascript
字符串反转_JavaScript
2016/04/28 Javascript
详解Backbone.js框架中的模型Model与其集合collection
2016/05/05 Javascript
ajax实现动态下拉框示例
2017/01/10 Javascript
localStorage的黑科技-js和css缓存机制
2017/02/06 Javascript
Angular 开发学习之Angular CLI的安装使用
2017/12/31 Javascript
vue源码学习之Object.defineProperty 对数组监听
2018/05/30 Javascript
微信小程序实现自定义picker选择器弹窗内容
2020/05/26 Javascript
vue获取元素宽、高、距离左边距离,右,上距离等还有XY坐标轴的方法
2018/09/05 Javascript
在vue中解决提示警告 for循环报错的方法
2018/09/28 Javascript
微信{"errcode":48001,"errmsg":"api unauthorized, hints: [ req_id: 1QoCla0699ns81 ]"}
2018/10/12 Javascript
微信小程序学习笔记之目录结构、基本配置图文详解
2019/03/28 Javascript
Python查找两个有序列表中位数的方法【基于归并算法】
2018/04/20 Python
windows下python 3.6.4安装配置图文教程
2018/08/21 Python
python使用正则表达式来获取文件名的前缀方法
2018/10/21 Python
Python字符串中添加、插入特定字符的方法
2019/09/10 Python
Python常用外部指令执行代码实例
2020/11/05 Python
奥地利票务门户网站:oeticket.com
2019/12/31 全球购物
新锐科技Java程序员面试题
2016/07/25 面试题
体育老师的教学自我评价分享
2013/11/19 职场文书
艺术设计专业求职自荐信
2014/05/19 职场文书
户籍证明格式
2014/09/15 职场文书
2014年监理工作总结范文
2014/11/17 职场文书
就业指导讲座心得体会
2016/01/15 职场文书
2016年共产党员公开承诺书
2016/03/24 职场文书
个人道歉信大全
2019/04/11 职场文书
SpringBoot快速入门详解
2021/07/21 Java/Android
nginx负载功能+nfs服务器功能解析
2022/02/28 Servers