Keras SGD 随机梯度下降优化器参数设置方式


Posted in Python onJune 19, 2020

SGD 随机梯度下降

Keras 中包含了各式优化器供我们使用,但通常我会倾向于使用 SGD 验证模型能否快速收敛,然后调整不同的学习速率看看模型最后的性能,然后再尝试使用其他优化器。

Keras 中文文档中对 SGD 的描述如下:

keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量

参数:

lr:大或等于0的浮点数,学习率

momentum:大或等于0的浮点数,动量参数

decay:大或等于0的浮点数,每次更新后的学习率衰减值

nesterov:布尔值,确定是否使用Nesterov动量

参数设置

Time-Based Learning Rate Schedule

Keras 已经内置了一个基于时间的学习速率调整表,并通过上述参数中的 decay 来实现,学习速率的调整公式如下:

LearningRate = LearningRate * 1/(1 + decay * epoch)

当我们初始化参数为:

LearningRate = 0.1
decay = 0.001

大致变化曲线如下(非实际曲线,仅示意):

Keras SGD 随机梯度下降优化器参数设置方式

当然,方便起见,我们可以将优化器设置如下,使其学习速率随着训练轮次变化:

sgd = SGD(lr=learning_rate, decay=learning_rate/nb_epoch, momentum=0.9, nesterov=True)

Drop-Based Learning Rate Schedule

另外一种学习速率的调整方法思路是保持一个恒定学习速率一段时间后立即降低,是一种突变的方式。通常整个变化趋势为指数形式。

Keras SGD 随机梯度下降优化器参数设置方式

对应的学习速率变化公式如下:

LearningRate = InitialLearningRate * DropRate^floor(Epoch / EpochDrop)

实现需要使用 Keras 中的 LearningRateScheduler 模块:

from keras.callbacks import LearningRateScheduler
# learning rate schedule
def step_decay(epoch):
 initial_lrate = 0.1
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
 return lrate

lrate = LearningRateScheduler(step_decay)

# Compile model
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.compile(loss=..., optimizer=sgd, metrics=['accuracy'])

# Fit the model
model.fit(X, Y, ..., callbacks=[lrate])

补充知识:keras中的BGD和SGD

关于BGD和SGD

首先BGD为批梯度下降,即所有样本计算完毕后才进行梯度更新;而SGD为随机梯度下降,随机计算一次样本就进行梯度下降,所以速度快很多但容易陷入局部最优值。

折中的办法是采用小批的梯度下降,即把数据分成若干个批次,一批来进行一次梯度下降,减少随机性,计算量也不是很大。 mini-batch

keras中的batch_size就是小批梯度下降。

以上这篇Keras SGD 随机梯度下降优化器参数设置方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用pycurl监控http响应时间脚本分享
Feb 02 Python
Python二分查找详解
Sep 13 Python
Python实现针对中文排序的方法
May 09 Python
Python装饰器实现几类验证功能做法实例
May 18 Python
python 中文件输入输出及os模块对文件系统的操作方法
Aug 27 Python
python将处理好的图像保存到指定目录下的方法
Jan 10 Python
对Django外键关系的描述
Jul 26 Python
softmax及python实现过程解析
Sep 30 Python
使用Python制作一个打字训练小工具
Oct 01 Python
python上传时包含boundary时的解决方法
Apr 08 Python
Jmeter HTTPS接口测试证书导入过程图解
Jul 22 Python
一个非常简单好用的Python图形界面库(PysimpleGUI)
Dec 28 Python
python支持多继承吗
Jun 19 #Python
python和php哪个容易学
Jun 19 #Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 #Python
Python魔术方法专题
Jun 19 #Python
关于Theano和Tensorflow多GPU使用问题
Jun 19 #Python
如何对python的字典进行排序
Jun 19 #Python
浅谈Python中的继承
Jun 19 #Python
You might like
php array_slice函数的使用以及参数详解
2008/08/30 PHP
PHP daddslashes 使用方法介绍
2012/10/26 PHP
Codeigniter实现处理用户登录验证后的URL跳转
2014/06/12 PHP
提高 DHTML 页面性能
2006/12/25 Javascript
一个简单的jquery的多选下拉框(自写)
2014/05/05 Javascript
javascript中attribute和property的区别详解
2014/06/05 Javascript
JS辨别访问浏览器判断是android还是ios系统
2014/08/19 Javascript
详解js闭包
2014/09/02 Javascript
基于jQuery实现仿51job城市选择功能实例代码
2016/03/02 Javascript
深入理解JavaScript中的并行处理
2016/09/22 Javascript
javascript汉字拼音互转的简单实例
2016/10/09 Javascript
JS 调用微信扫一扫功能
2016/12/22 Javascript
快速掌握jQuery插件开发
2017/01/19 Javascript
vue 使用自定义指令实现表单校验的方法
2018/08/28 Javascript
vue-cli的build的文件夹下没有dev-server.js文件配置mock数据的方法
2019/04/17 Javascript
react 不用插件实现数字滚动的效果示例
2020/04/14 Javascript
为Python的web框架编写前端模版的教程
2015/04/30 Python
python机器人行走步数问题的解决
2018/01/29 Python
Python数据处理numpy.median的实例讲解
2018/04/02 Python
python操作xlsx文件的包openpyxl实例
2018/05/03 Python
python连接PostgreSQL数据库的过程详解
2019/09/18 Python
Python如何实现强制数据类型转换
2019/11/22 Python
tensorflow中tf.slice和tf.gather切片函数的使用
2020/01/19 Python
python中upper是做什么用的
2020/07/20 Python
如何基于python实现年会抽奖工具
2020/10/20 Python
Django多个app urls配置代码实例
2020/11/26 Python
HTML5学习笔记之History API
2015/02/26 HTML / CSS
Stuart Weitzman美国官网:美国奢华鞋履品牌
2016/08/18 全球购物
Herschel Supply Co.美国:背包、手提袋及配件
2020/11/24 全球购物
教师实习的自我鉴定
2013/10/26 职场文书
2014小学植树节活动总结
2014/03/10 职场文书
地理科学专业自荐信
2014/09/01 职场文书
学习十八大的心得体会
2014/09/12 职场文书
2014业务员年终工作总结
2014/12/09 职场文书
MYSQL如何查看进程和kill进程
2022/03/13 MySQL
Python实现提取PDF简历信息并存入Excel
2022/04/02 Python