pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows中使用wxPython和py2exe开发Python的GUI程序的实例教程
Jul 11 Python
Python中django学习心得
Dec 06 Python
Python使用matplotlib实现绘制自定义图形功能示例
Jan 18 Python
tensorflow训练中出现nan问题的解决
Feb 10 Python
Python自动发送邮件的方法实例总结
Dec 08 Python
Python实现E-Mail收集插件实例教程
Feb 06 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
python和c语言的主要区别总结
Jul 07 Python
Python全栈之列表数据类型详解
Oct 01 Python
浅谈python之自动化运维(Paramiko)
Jan 31 Python
Pycharm Git 设置方法
Sep 15 Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
jQuery 瀑布流 绝对定位布局(二)(延迟AJAX加载图片)
2012/05/23 Javascript
JQuery的Ajax请求实现局部刷新的简单实例
2014/02/11 Javascript
jQuery实现html元素拖拽
2015/07/21 Javascript
实例代码讲解jquery easyui动态tab页
2015/11/17 Javascript
Extjs实现下拉菜单效果
2016/04/01 Javascript
Bootstrap安装环境配置教程分享
2016/05/27 Javascript
Bootstrap导航条可点击和鼠标悬停显示下拉菜单的实现代码
2016/06/23 Javascript
AngularJS控制器controller给模型数据赋初始值的方法
2017/01/04 Javascript
Node.js 中exports 和 module.exports 的区别
2017/03/14 Javascript
AngularJS实现页面跳转后自动弹出对话框实例代码
2017/08/02 Javascript
封装运动框架实战左右与上下滑动的焦点轮播图(实例)
2017/10/17 Javascript
利用vue + element实现表格分页和前端搜索的方法
2017/12/25 Javascript
javascript canvas封装动态时钟
2020/09/30 Javascript
element中Steps步骤条和Tabs标签页关联的解决
2020/12/08 Javascript
python回溯法实现数组全排列输出实例分析
2015/03/17 Python
Python json 错误xx is not JSON serializable解决办法
2017/03/15 Python
浅谈python中的占位符
2017/11/09 Python
Python中的pygal安装和绘制直方图代码分享
2017/12/08 Python
Python 将pdf转成图片的方法
2018/04/23 Python
python实现给微信指定好友定时发送消息
2019/04/29 Python
Python3 执行Linux Bash命令的方法
2019/07/12 Python
解决pandas展示数据输出时列名不能对齐的问题
2019/11/18 Python
python使用pygame实现笑脸乒乓球弹珠球游戏
2019/11/25 Python
Python 实现向word(docx)中输出
2020/02/13 Python
python 常见的反爬虫策略
2020/09/27 Python
Python confluent kafka客户端配置kerberos认证流程详解
2020/10/12 Python
python推导式的使用方法实例
2021/02/28 Python
使用CSS3制作版头动画效果
2020/12/24 HTML / CSS
凯特方迪化妆品官网:Kat Von D Beauty
2016/11/15 全球购物
总裁岗位职责
2013/12/04 职场文书
财务内勤岗位职责
2014/04/17 职场文书
学校领导四风问题整改措施思想汇报
2014/10/09 职场文书
2015年党员个人工作总结
2015/05/13 职场文书
升学宴家长致辞
2015/07/27 职场文书
Python进阶学习之带你探寻Python类的鼻祖-元类
2021/05/08 Python
Spring Cloud Gateway去掉url前缀
2021/07/15 Java/Android