PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之Import 模块
Oct 13 Python
python 调用HBase的简单实例
Dec 18 Python
python绘制直线的方法
Jun 30 Python
linux中如何使用python3获取ip地址
Jul 15 Python
python中类的输出或类的实例输出为这种形式的原因
Aug 12 Python
Python搭建HTTP服务过程图解
Dec 14 Python
python 读写文件包含多种编码格式的解决方式
Dec 20 Python
Python3变量与基本数据类型用法实例分析
Feb 14 Python
Python sqlite3查询操作过程解析
Feb 20 Python
python GUI库图形界面开发之PyQt5拖放控件实例详解
Feb 25 Python
Python爬虫后获取重定向url的两种方法
Jan 19 Python
浅谈Python协程asyncio
Jun 20 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
php中session_unset与session_destroy的区别分析
2011/06/16 PHP
PHP版国家代码、缩写查询函数代码
2011/08/14 PHP
深入探讨:Nginx 502 Bad Gateway错误的解决方法
2013/06/03 PHP
php过滤表单提交的html等危险代码
2014/11/03 PHP
PHP简单获取网站百度搜索和搜狗搜索收录量的方法
2016/08/23 PHP
PHP实现将优酷土豆腾讯视频html地址转换成flash swf地址的方法
2017/08/04 PHP
SeaJS入门教程系列之完整示例(三)
2014/03/03 Javascript
基于Jquery easyui 选中特定的tab
2015/11/17 Javascript
实现React单页应用的方法详解
2016/08/02 Javascript
利用JavaScript实现拖拽改变元素大小
2016/12/14 Javascript
浅谈React深度编程之受控组件与非受控组件
2017/12/26 Javascript
利用SpringMVC过滤器解决vue跨域请求的问题
2018/02/10 Javascript
WebPack配置vue多页面的技巧
2018/05/15 Javascript
vue+iview 实现可编辑表格的示例代码
2018/10/31 Javascript
微信小程序实现上传多张图片、删除图片
2020/07/29 Javascript
[07:03]显微镜下的DOTA2第九期——430圣堂刺客杀戮秀
2014/06/20 DOTA
[40:55]Liquid vs LGD 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
[01:10:57]Liquid vs OG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
Python数据分析之真实IP请求Pandas详解
2016/11/18 Python
python机器学习之神经网络(三)
2017/12/20 Python
Python适配器模式代码实现解析
2019/08/02 Python
python中时间转换datetime和pd.to_datetime详析
2019/08/11 Python
Python Django 页面上展示固定的页码数实现代码
2019/08/21 Python
Django celery异步任务实现代码示例
2020/11/26 Python
python 实现socket服务端并发的四种方式
2020/12/14 Python
Canvas实现贝赛尔曲线轨迹动画的示例代码
2019/04/25 HTML / CSS
HTML5在手机端实现视频全屏展示方法
2020/11/23 HTML / CSS
英国珠宝钟表和家居礼品精品店:David Shuttle
2018/02/24 全球购物
机械专业个人求职自荐信格式
2013/09/21 职场文书
企业文化标语大全
2014/06/10 职场文书
亚运会口号
2014/06/20 职场文书
自愿离婚协议书范文2014
2014/10/12 职场文书
环保主题班会教案
2015/08/13 职场文书
漫画「古见同学有交流障碍症」第25卷封面公开
2022/03/21 日漫
python在package下继续嵌套一个package
2022/04/14 Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
2022/05/04 Python