Keras SGD 随机梯度下降优化器参数设置方式


Posted in Python onJune 19, 2020

SGD 随机梯度下降

Keras 中包含了各式优化器供我们使用,但通常我会倾向于使用 SGD 验证模型能否快速收敛,然后调整不同的学习速率看看模型最后的性能,然后再尝试使用其他优化器。

Keras 中文文档中对 SGD 的描述如下:

keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量

参数:

lr:大或等于0的浮点数,学习率

momentum:大或等于0的浮点数,动量参数

decay:大或等于0的浮点数,每次更新后的学习率衰减值

nesterov:布尔值,确定是否使用Nesterov动量

参数设置

Time-Based Learning Rate Schedule

Keras 已经内置了一个基于时间的学习速率调整表,并通过上述参数中的 decay 来实现,学习速率的调整公式如下:

LearningRate = LearningRate * 1/(1 + decay * epoch)

当我们初始化参数为:

LearningRate = 0.1
decay = 0.001

大致变化曲线如下(非实际曲线,仅示意):

Keras SGD 随机梯度下降优化器参数设置方式

当然,方便起见,我们可以将优化器设置如下,使其学习速率随着训练轮次变化:

sgd = SGD(lr=learning_rate, decay=learning_rate/nb_epoch, momentum=0.9, nesterov=True)

Drop-Based Learning Rate Schedule

另外一种学习速率的调整方法思路是保持一个恒定学习速率一段时间后立即降低,是一种突变的方式。通常整个变化趋势为指数形式。

Keras SGD 随机梯度下降优化器参数设置方式

对应的学习速率变化公式如下:

LearningRate = InitialLearningRate * DropRate^floor(Epoch / EpochDrop)

实现需要使用 Keras 中的 LearningRateScheduler 模块:

from keras.callbacks import LearningRateScheduler
# learning rate schedule
def step_decay(epoch):
 initial_lrate = 0.1
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
 return lrate

lrate = LearningRateScheduler(step_decay)

# Compile model
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.compile(loss=..., optimizer=sgd, metrics=['accuracy'])

# Fit the model
model.fit(X, Y, ..., callbacks=[lrate])

补充知识:keras中的BGD和SGD

关于BGD和SGD

首先BGD为批梯度下降,即所有样本计算完毕后才进行梯度更新;而SGD为随机梯度下降,随机计算一次样本就进行梯度下降,所以速度快很多但容易陷入局部最优值。

折中的办法是采用小批的梯度下降,即把数据分成若干个批次,一批来进行一次梯度下降,减少随机性,计算量也不是很大。 mini-batch

keras中的batch_size就是小批梯度下降。

以上这篇Keras SGD 随机梯度下降优化器参数设置方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取网页图片并放到指定文件夹
Apr 24 Python
Python 列表(List) 的三种遍历方法实例 详解
Apr 15 Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 Python
Python 给某个文件名添加时间戳的方法
Oct 16 Python
python3使用腾讯企业邮箱发送邮件的实例
Jun 28 Python
python3 selenium自动化测试 强大的CSS定位方法
Aug 23 Python
python pygame实现滚动横版射击游戏城市之战
Nov 25 Python
Pycharm及python安装详细步骤及PyCharm配置整理(推荐)
Jul 31 Python
pytorch 查看cuda 版本方式
Jun 23 Python
解决Python安装cryptography报错问题
Sep 03 Python
python实现人性化显示金额数字实例详解
Sep 25 Python
python 实现全球IP归属地查询工具
Dec 18 Python
python支持多继承吗
Jun 19 #Python
python和php哪个容易学
Jun 19 #Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 #Python
Python魔术方法专题
Jun 19 #Python
关于Theano和Tensorflow多GPU使用问题
Jun 19 #Python
如何对python的字典进行排序
Jun 19 #Python
浅谈Python中的继承
Jun 19 #Python
You might like
PHP 七大优势分析
2009/06/23 PHP
php过滤所有恶意字符(批量过滤post,get敏感数据)
2014/03/18 PHP
php版微信公众平台接口参数调试实现判断用户行为的方法
2016/09/23 PHP
Yii2压缩PHP中模板代码的输出问题
2018/08/28 PHP
PHP PDOStatement::execute讲解
2019/01/31 PHP
PHP实现唤起微信支付功能
2019/02/18 PHP
javascript编程起步(第三课)
2007/02/27 Javascript
Javascript事件热键兼容ie|firefox
2010/12/30 Javascript
Javascript中定义方法的另类写法(批量定义js对象的方法)
2011/02/25 Javascript
JS中Iframe之间传值的方法
2013/03/11 Javascript
jQuery 追加元素的方法如append、prepend、before
2014/01/16 Javascript
javascript向后台传送相同属性的参数即数组参数
2014/02/17 Javascript
jquery实现人性化的有选择性禁用鼠标右键
2014/06/30 Javascript
JavaScript中的style.cssText使用教程
2014/11/06 Javascript
node.js中的http.response.addTrailers方法使用说明
2014/12/14 Javascript
jQuery实现拖动调整表格单元格大小的代码实例
2015/01/13 Javascript
JavaScript中的getMilliseconds()方法使用详解
2015/06/10 Javascript
JS扩展类,克隆对象与混合类实例分析
2016/11/26 Javascript
深入理解Javascript箭头函数中的this
2017/02/13 Javascript
基于JavaScript实现评论框展开和隐藏功能
2017/08/25 Javascript
nodejs对express中next函数的一些理解
2017/09/08 NodeJs
详解JavaScript中typeof与instanceof用法
2018/10/24 Javascript
JS实现根据数组对象的某一属性排序操作示例
2019/01/14 Javascript
js实现转动骰子模型
2019/10/24 Javascript
js前端如何写一个精确的倒计时代码
2019/10/25 Javascript
微信小程序学习总结(五)常见问题实例小结
2020/06/04 Javascript
Python中实现两个字典(dict)合并的方法
2014/09/23 Python
python中__call__内置函数用法实例
2015/06/04 Python
Python实现按逗号分隔列表的方法
2018/10/23 Python
Laravel框架表单验证格式化输出的方法
2019/09/25 Python
python绘图模块之利用turtle画图
2021/02/12 Python
利用Opencv实现图片的油画特效实例
2021/02/28 Python
HTML5开发动态音频图的实现
2020/07/02 HTML / CSS
俄罗斯商务邀请函
2014/01/26 职场文书
体育教师求职信
2014/05/24 职场文书
个性发展自我评价2015
2015/03/09 职场文书