Keras SGD 随机梯度下降优化器参数设置方式


Posted in Python onJune 19, 2020

SGD 随机梯度下降

Keras 中包含了各式优化器供我们使用,但通常我会倾向于使用 SGD 验证模型能否快速收敛,然后调整不同的学习速率看看模型最后的性能,然后再尝试使用其他优化器。

Keras 中文文档中对 SGD 的描述如下:

keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量

参数:

lr:大或等于0的浮点数,学习率

momentum:大或等于0的浮点数,动量参数

decay:大或等于0的浮点数,每次更新后的学习率衰减值

nesterov:布尔值,确定是否使用Nesterov动量

参数设置

Time-Based Learning Rate Schedule

Keras 已经内置了一个基于时间的学习速率调整表,并通过上述参数中的 decay 来实现,学习速率的调整公式如下:

LearningRate = LearningRate * 1/(1 + decay * epoch)

当我们初始化参数为:

LearningRate = 0.1
decay = 0.001

大致变化曲线如下(非实际曲线,仅示意):

Keras SGD 随机梯度下降优化器参数设置方式

当然,方便起见,我们可以将优化器设置如下,使其学习速率随着训练轮次变化:

sgd = SGD(lr=learning_rate, decay=learning_rate/nb_epoch, momentum=0.9, nesterov=True)

Drop-Based Learning Rate Schedule

另外一种学习速率的调整方法思路是保持一个恒定学习速率一段时间后立即降低,是一种突变的方式。通常整个变化趋势为指数形式。

Keras SGD 随机梯度下降优化器参数设置方式

对应的学习速率变化公式如下:

LearningRate = InitialLearningRate * DropRate^floor(Epoch / EpochDrop)

实现需要使用 Keras 中的 LearningRateScheduler 模块:

from keras.callbacks import LearningRateScheduler
# learning rate schedule
def step_decay(epoch):
 initial_lrate = 0.1
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
 return lrate

lrate = LearningRateScheduler(step_decay)

# Compile model
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.compile(loss=..., optimizer=sgd, metrics=['accuracy'])

# Fit the model
model.fit(X, Y, ..., callbacks=[lrate])

补充知识:keras中的BGD和SGD

关于BGD和SGD

首先BGD为批梯度下降,即所有样本计算完毕后才进行梯度更新;而SGD为随机梯度下降,随机计算一次样本就进行梯度下降,所以速度快很多但容易陷入局部最优值。

折中的办法是采用小批的梯度下降,即把数据分成若干个批次,一批来进行一次梯度下降,减少随机性,计算量也不是很大。 mini-batch

keras中的batch_size就是小批梯度下降。

以上这篇Keras SGD 随机梯度下降优化器参数设置方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
微信小程序跳一跳游戏 python脚本跳一跳刷高分技巧
Jan 04 Python
python批量修改图片大小的方法
Jul 24 Python
对DataFrame数据中的重复行,利用groupby累加合并的方法详解
Jan 30 Python
详解Python匿名函数(lambda函数)
Apr 19 Python
Python实现的远程文件自动打包并下载功能示例
Jul 12 Python
Python Web程序搭建简单的Web服务器
Jul 31 Python
Python enumerate函数遍历数据对象组合过程解析
Dec 11 Python
Django Admin设置应用程序及模型顺序方法详解
Apr 01 Python
python3 使用openpyxl将mysql数据写入xlsx的操作
May 15 Python
python如何更新包
Jun 11 Python
Django filter动态过滤与排序实现过程解析
Nov 26 Python
用Python写一个简易版弹球游戏
Apr 13 Python
python支持多继承吗
Jun 19 #Python
python和php哪个容易学
Jun 19 #Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 #Python
Python魔术方法专题
Jun 19 #Python
关于Theano和Tensorflow多GPU使用问题
Jun 19 #Python
如何对python的字典进行排序
Jun 19 #Python
浅谈Python中的继承
Jun 19 #Python
You might like
php支付宝手机网页支付类实例
2015/03/04 PHP
javascript 类定义的4种方法
2009/09/12 Javascript
js 获取子节点函数 (兼容FF与IE)
2010/04/18 Javascript
Javascript中数组sort和reverse用法分析
2014/12/30 Javascript
走进AngularJs之过滤器(filter)详解
2017/02/17 Javascript
js获取当前周、上一周、下一周日期
2017/03/19 Javascript
详解Vue2.X的路由管理记录之 钩子函数(切割流水线)
2017/05/02 Javascript
vue框架中props的typescript用法详解
2020/02/17 Javascript
Jquery高级应用Deferred对象原理及使用实例
2020/05/28 jQuery
JavaScript中的全局属性与方法深入解析
2020/06/14 Javascript
js实现金山打字通小游戏
2020/07/24 Javascript
python连接mysql并提交mysql事务示例
2014/03/05 Python
python实现批量获取指定文件夹下的所有文件的厂商信息
2014/09/28 Python
python中requests小技巧
2017/05/10 Python
Python开发的十个小贴士和技巧及长常犯错误
2018/09/27 Python
对python生成业务报表的实例详解
2019/02/03 Python
python判断无向图环是否存在的示例
2019/11/22 Python
利用python实现逐步回归
2020/02/24 Python
django之导入并执行自定义的函数模块图解
2020/04/01 Python
浅谈keras使用中val_acc和acc值不同步的思考
2020/06/18 Python
Python如何实现感知器的逻辑电路
2020/12/25 Python
英国最大的在线奢侈手表零售商:Jura Watches
2018/01/29 全球购物
Auguste The Label官网:澳大利亚一家精品女装时尚品牌
2020/06/14 全球购物
Java的基础面试题附答案
2016/01/10 面试题
试述DBMS的主要功能
2016/11/13 面试题
什么是Linux虚拟文件系统VFS
2015/08/25 面试题
EJB发布WEB服务一般步骤
2012/10/31 面试题
30岁生日感言
2014/01/25 职场文书
红领巾广播站广播稿(3篇)
2014/09/20 职场文书
2014年幼儿园教学工作总结
2014/12/04 职场文书
详解thinkphp的Auth类认证
2021/05/28 PHP
Python中requests做接口测试的方法
2021/05/30 Python
Python集合的基础操作
2021/11/01 Python
Go语言基础map用法及示例详解
2021/11/17 Golang
详细介绍python操作RabbitMq
2022/04/12 Python
MySQL三种方式实现递归查询
2022/04/18 MySQL