pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进程管理工具supervisor的安装与使用教程
Sep 05 Python
Python使用getpass库读取密码的示例
Oct 10 Python
python使用jieba实现中文分词去停用词方法示例
Mar 11 Python
python命令行参数用法实例分析
Jun 25 Python
TensorFlow实现从txt文件读取数据
Feb 05 Python
python GUI库图形界面开发之PyQt5时间控件QTimer详细使用方法与实例
Feb 26 Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 Python
基于python SMTP实现自动发送邮件教程解析
Jun 02 Python
对python中list的五种查找方法说明
Jul 13 Python
Python join()函数原理及使用方法
Nov 14 Python
Python环境配置实现pip加速过程解析
Nov 27 Python
python 中[0]*2与0*2的区别说明
May 10 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
使用php判断浏览器的类型和语言的函数代码
2013/02/28 PHP
php解决约瑟夫环示例
2014/04/09 PHP
Linux环境下php实现给网站截图的方法
2016/05/03 PHP
一款JavaScript压缩工具:X2JSCompactor
2007/06/13 Javascript
对字符串进行HTML编码和解码的JavaScript函数
2010/02/01 Javascript
jquery ajax 局部刷新小案例
2014/02/08 Javascript
解决js下referer兼容各大浏览器的方法
2014/11/03 Javascript
AngularJS入门教程(零):引导程序
2014/12/06 Javascript
针对初学者的jQuery入门指南
2015/08/15 Javascript
JS中使用FormData上传文件、图片的方法
2016/08/07 Javascript
Bootstrap CSS布局之表格
2016/12/17 Javascript
JavaScript使用readAsDataUrl方法预览图片
2017/05/10 Javascript
jQuery选取所有复选框被选中的值并用Ajax异步提交数据的实例
2017/08/04 jQuery
js+css实现红包雨效果
2018/07/12 Javascript
jQuery移动端跑马灯抽奖特效升级版(抽奖概率固定)实现方法
2019/01/18 jQuery
弱类型语言javascript开发中的一些坑实例小结【变量、函数、数组、对象、作用域等】
2019/08/07 Javascript
Vue el-autocomplete远程搜索下拉框并实现自动填充功能(推荐)
2019/10/25 Javascript
vue-cli3自动消除console.log()的调试信息方式
2020/10/21 Javascript
[01:33]完美世界DOTA2联赛PWL S3 集锦第二期
2020/12/21 DOTA
python原始套接字编程示例分享
2014/02/21 Python
Python中list列表的一些进阶使用方法介绍
2015/08/15 Python
Python实现OpenCV的安装与使用示例
2018/03/30 Python
Python使用win32 COM实现Excel的写入与保存功能示例
2018/05/03 Python
pycharm远程linux开发和调试代码的方法
2018/07/17 Python
Python错误处理操作示例
2018/07/18 Python
python3实现钉钉消息推送的方法示例
2019/03/14 Python
使用Python实现跳帧截取视频帧
2019/05/31 Python
selenium+PhantomJS爬取豆瓣读书
2019/08/26 Python
VSCODE配置Markdown及Markdown基础语法详解
2021/01/19 Python
中软国际Java程序员笔试题
2014/07/19 面试题
法学专业毕业生自荐信范文
2013/12/18 职场文书
篝火晚会主持词
2014/03/25 职场文书
大学生团日活动总结
2015/05/06 职场文书
学校少先队工作总结
2015/08/12 职场文书
《语言的突破》读后感3篇
2019/12/12 职场文书
MySQL悲观锁与乐观锁的实现方案
2021/11/02 MySQL