pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python提示No module named images的解决方法
Sep 29 Python
跟老齐学Python之编写类之二方法
Oct 11 Python
Python中基础的socket编程实战攻略
Jun 01 Python
python爬虫实现教程转换成 PDF 电子书
Feb 19 Python
通过Python 获取Android设备信息的轻量级框架
Dec 18 Python
使用python验证代理ip是否可用的实现方法
Jul 25 Python
详解python之heapq模块及排序操作
Apr 04 Python
python中栈的原理及实现方法示例
Nov 27 Python
Python selenium页面加载慢超时的解决方案
Mar 18 Python
Anaconda的安装与虚拟环境建立
Nov 18 Python
python解包概念及实例
Feb 17 Python
Python创建自己的加密货币的示例
Mar 01 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
为了这两部电子管收音机,买了6套全新电子管和10粒刻度盘灯泡
2021/03/02 无线电
不用GD库生成当前时间的PNG格式图象的程序
2006/10/09 PHP
PHP运用foreach神奇的转换数组(实例讲解)
2018/02/01 PHP
php面向对象程序设计入门教程
2019/06/22 PHP
Laravel 手动开关 Eloquent 修改器的操作方法
2019/12/30 PHP
jQuery 过滤not()与filter()实例代码
2012/05/10 Javascript
javascript改变position值实现菜单滚动至顶部后固定
2013/01/18 Javascript
简介JavaScript中Math.cos()余弦方法的使用
2015/06/15 Javascript
如何高效率去掉js数组中的重复项
2016/04/12 Javascript
微信公众号-获取用户信息(网页授权获取)实现步骤
2016/10/21 Javascript
Bootstrap 3.x打印预览背景色与文字显示异常的解决
2016/11/06 Javascript
vue.js+Element实现表格里的增删改查
2017/01/18 Javascript
JS实现图片预加载之无序预加载功能代码
2017/05/12 Javascript
node.js中express中间件body-parser的介绍与用法详解
2017/05/23 Javascript
使用Vue动态生成form表单的实例代码
2018/04/26 Javascript
Vue实现商品详情页的评价列表功能
2019/09/04 Javascript
vue子传父关于.sync与$emit的实现
2019/11/05 Javascript
js实现盒子拖拽动画效果
2020/08/09 Javascript
python中List的sort方法指南
2014/09/01 Python
Python中datetime常用时间处理方法
2015/06/15 Python
Python多层嵌套list的递归处理方法(推荐)
2016/06/08 Python
Python tkinter事件高级用法实例
2018/01/31 Python
CentOS 7下安装Python3.6 及遇到的问题小结
2018/11/08 Python
pyshp创建shp点文件的方法
2018/12/31 Python
Python向excel中写入数据的方法
2019/05/05 Python
Python for i in range ()用法详解
2020/09/18 Python
Steiff台湾官网:德国金耳釦泰迪熊
2019/12/26 全球购物
捷克多品牌在线时尚商店:ANSWEAR.cz
2020/10/03 全球购物
挖掘机司机岗位职责
2014/02/12 职场文书
2014年三万活动总结
2014/04/26 职场文书
婚前保证书
2014/04/29 职场文书
Python实现机器学习算法的分类
2021/06/03 Python
Java SSM配置文件案例详解
2021/08/30 Java/Android
简单聊聊Vue中的计算属性和属性侦听
2021/10/05 Vue.js
Windows Server 2019 安装DHCP服务及相关配置
2022/04/28 Servers
Win11 21h2可以升级22h2吗?看看你的电脑符不符合要求
2022/07/07 数码科技