pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用成员运算符的示例
May 13 Python
Python排序搜索基本算法之冒泡排序实例分析
Dec 09 Python
Python数据结构与算法之图的广度优先与深度优先搜索算法示例
Dec 14 Python
Python3 安装PyQt5及exe打包图文教程
Jan 08 Python
python 对字典按照value进行排序的方法
May 09 Python
python matplotlib.pyplot.plot()参数用法
Apr 14 Python
解决Jupyter NoteBook输出的图表太小看不清问题
Apr 16 Python
什么是Python中的顺序表
Jun 02 Python
python爬虫使用正则爬取网站的实现
Aug 03 Python
python 实现简单的计算器(gui界面)
Nov 11 Python
Pytorch如何切换 cpu和gpu的使用详解
Mar 01 Python
Python 线程池模块之多线程操作代码
May 20 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
[原创]PHP获取数组表示的路径方法分析【数组转字符串】
2017/09/01 PHP
PHP数字金额转换成中文大写显示
2019/01/05 PHP
jquery 上下滚动广告
2009/06/17 Javascript
测试JavaScript字符串处理性能的代码
2009/12/07 Javascript
使用documentElement正确取得当前可见区域的大小
2014/07/25 Javascript
从JQuery源码分析JavaScript函数的apply方法与call方法
2014/09/25 Javascript
JS实现转动随机数抽奖特效代码
2020/04/16 Javascript
jQuery form插件之ajaxForm()和ajaxSubmit()的可选参数项对象
2016/01/23 Javascript
基于jQuery实现二级下拉菜单效果
2016/02/01 Javascript
javascript+HTML5自定义元素播放焦点图动画
2016/02/21 Javascript
纯css下拉菜单 无需js
2016/08/15 Javascript
用headjs来管理和加载js 提高网站加载速度
2016/11/29 Javascript
jQuery中table数据的值拷贝和拆分
2017/03/19 Javascript
Vue中计算属性computed的示例解读
2017/07/26 Javascript
详解为element-ui的Select和Cascader添加弹层底部操作按钮
2020/02/07 Javascript
JS异步宏队列微队列原理详解
2020/09/09 Javascript
[02:49]DAC2018决赛日TOP5 LGD开启黑暗之门绝杀VP
2018/04/08 DOTA
python实现将文本转换成语音的方法
2015/05/28 Python
Python获取央视节目单的实现代码
2015/07/25 Python
Python操作MySQL数据库的三种方法总结
2018/01/30 Python
Python简单实现网页内容抓取功能示例
2018/06/07 Python
Python DataFrame.groupby()聚合函数,分组级运算
2018/09/18 Python
使用python实现http及ftp服务进行数据传输的方法
2018/10/26 Python
Python魔术方法专题
2020/06/19 Python
pycharm配置QtDesigner的超详细方法
2021/01/25 Python
详解Java中一维、二维数组在内存中的结构
2021/02/11 Python
css3实现的多级渐变下拉菜单导航效果代码
2015/08/31 HTML / CSS
名词解释WEB SERVICE,SOAP,UDDI,WSDL,JAXP,JAXM;JSWDL开发包的介绍。
2012/10/27 面试题
介绍一下.net和Java的特点和区别
2012/09/26 面试题
广告学专业自荐信范文
2014/02/24 职场文书
2015年社区民政工作总结
2015/04/21 职场文书
员工保密协议范本,您一定得收藏!很有用!
2019/08/08 职场文书
Python WSGI 规范简介
2021/04/11 Python
CSS3实现的水平标题菜单
2021/04/14 HTML / CSS
详细谈谈MYSQL中的COLLATE是什么
2021/06/11 MySQL
德生BCL3000抢先使用感受和评价
2022/04/07 无线电