pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 输出一个两行字符的变量
Feb 05 Python
Python快速从注释生成文档的方法
Dec 26 Python
Python基于回溯法子集树模板实现8皇后问题
Sep 01 Python
用python处理图片之打开\显示\保存图像的方法
May 04 Python
python实现决策树ID3算法的示例代码
May 30 Python
python 计算平均平方误差(MSE)的实例
Jun 29 Python
python3实现二叉树的遍历与递归算法解析(小结)
Jul 03 Python
python英语单词测试小程序代码实例
Sep 09 Python
python 五子棋如何获得鼠标点击坐标
Nov 04 Python
django rest framework serializers序列化实例
May 13 Python
python 爬虫如何正确的使用cookie
Oct 27 Python
Python爬取网站图片并保存的实现示例
Feb 26 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
PHP中GET变量的使用
2006/10/09 PHP
PHP命名空间(Namespace)的使用详解
2013/05/04 PHP
php图像处理函数大全(推荐收藏)
2013/07/11 PHP
phalcon框架使用指南
2016/02/23 PHP
Zend Framework教程之Zend_Controller_Plugin插件用法详解
2016/03/07 PHP
PHP模板引擎Smarty内建函数section,sectionelse用法详解
2016/04/11 PHP
php的laravel框架快速集成微信登录的方法
2016/12/12 PHP
Yii框架模拟组件调用注入示例
2019/11/11 PHP
jQuery 获取/设置/删除DOM元素的属性以a元素为例
2014/05/23 Javascript
jquery中JSON的解析方式
2015/03/16 Javascript
jQuery仿淘宝网产品品牌隐藏与显示效果
2015/09/01 Javascript
JS图片定时翻滚效果实现方法
2016/06/21 Javascript
Nodejs中 npm常用命令详解
2016/07/04 NodeJs
基于JS对象创建常用方式及原理分析
2017/06/28 Javascript
JavaScript实现三级级联特效
2017/11/05 Javascript
JavaScript链式调用实例浅析
2018/12/19 Javascript
Vue过渡效果之CSS过渡详解(结合transition,animation,animate.css)
2020/02/05 Javascript
Vue项目移动端滚动穿透问题的实现
2020/05/19 Javascript
从零学python系列之浅谈pickle模块封装和拆封数据对象的方法
2014/05/23 Python
Python使用遗传算法解决最大流问题
2018/01/29 Python
python实现根据文件关键字进行切分为多个文件的示例
2018/12/10 Python
Python连接SQLite数据库并进行增册改查操作方法详解
2020/02/18 Python
Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境
2020/04/08 Python
使用Python下载抖音各大V视频的思路详解
2021/02/06 Python
CSS3制作炫酷带方向感应的鼠标滑过图片3D动画
2016/03/16 HTML / CSS
ghd法国官方网站:英国最受欢迎的美发工具品牌
2019/04/18 全球购物
英国最大的汽车配件在线商店:Euro Car Parts
2019/09/30 全球购物
光荣入党自我鉴定
2014/01/22 职场文书
领导干部作风整顿个人剖析材料
2014/10/11 职场文书
神农溪导游词
2015/02/11 职场文书
2015年前台文员工作总结
2015/05/18 职场文书
莫言获奖感言(全文)
2015/07/31 职场文书
解决golang post文件时Content-Type出现的问题
2021/05/02 Golang
OpenCV实现反阈值二值化
2021/11/17 Java/Android
如何利用Python实现n*n螺旋矩阵
2022/01/18 Python
关于python3 opencv 图像二值化的问题(cv2.adaptiveThreshold函数)
2022/04/04 Python