PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python contextlib模块使用示例
Feb 18 Python
Python的Django框架中forms表单类的使用方法详解
Jun 21 Python
python 网络编程详解及简单实例
Apr 25 Python
python对DICOM图像的读取方法详解
Jul 17 Python
Django学习教程之静态文件的调用详解
May 08 Python
Python带动态参数功能的sqlite工具类
May 26 Python
使用Python的OpenCV模块识别滑动验证码的缺口(推荐)
May 10 Python
sklearn-SVC实现与类参数详解
Dec 10 Python
Python通过yagmail实现发送邮件代码解析
Oct 27 Python
python小技巧——将变量保存在本地及读取
Nov 13 Python
django中ImageField的使用详解
Dec 21 Python
MATLAB 如何求取离散点的曲率最大值
Apr 16 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
PHP Cookie的使用教程详解
2013/06/03 PHP
PHP自定义大小验证码的方法详解
2013/06/07 PHP
php获取bing每日壁纸示例分享
2014/02/25 PHP
php更新mysql后获取改变行数的方法
2014/12/25 PHP
php-msf源码详解
2017/12/25 PHP
jquery对表单操作2
2011/04/06 Javascript
javascript中获取下个月一号,是星期几
2012/06/01 Javascript
增强用户体验友好性之jquery easyui window 窗口关闭时的提示
2012/06/22 Javascript
jquery 通过name快速取值示例
2014/01/24 Javascript
使用jquery组件qrcode生成二维码及应用指南
2015/02/22 Javascript
ES6中的数组扩展方法
2016/08/26 Javascript
jQuery右下角悬浮广告实例
2016/10/17 Javascript
JS插件plupload.js实现多图上传并显示进度条
2016/11/29 Javascript
Node.js websocket使用socket.io库实现实时聊天室
2017/02/20 Javascript
canvas基础绘制-绚丽倒计时的实例
2017/09/17 Javascript
jQuery实现的滑块滑动导航效果示例
2018/06/04 jQuery
详解Vue SPA项目优化小记
2018/07/03 Javascript
vue中组件的过渡动画及实现代码
2018/11/21 Javascript
javascript Canvas动态粒子连线
2020/01/01 Javascript
python 集合 并集、交集 Series list set 转换的实例
2018/05/29 Python
django框架自定义用户表操作示例
2018/08/07 Python
使用Python 统计高频字数的方法
2019/01/31 Python
Python实现的栈、队列、文件目录遍历操作示例
2019/05/06 Python
python打印9宫格、25宫格等奇数格 满足横竖斜相加和相等
2019/07/19 Python
python 自动识别并连接串口的实现
2021/01/19 Python
使用CSS3的font-face字体嵌入样式的方法讲解
2016/05/13 HTML / CSS
Skyscanner英国:苏格兰的全球三大领先航班搜索服务之一
2017/11/09 全球购物
护理个人求职信范文
2014/01/08 职场文书
初中物理教学反思
2014/01/14 职场文书
工厂仓管员岗位职责范本
2014/07/17 职场文书
法人授权委托书
2014/09/16 职场文书
死亡赔偿协议书
2015/01/28 职场文书
2015年世界无烟日活动方案
2015/05/04 职场文书
公司安全管理制度范本
2015/08/05 职场文书
学习计划是什么
2019/04/30 职场文书
Java十分钟精通进阶适配器模式
2022/04/06 Java/Android