PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python并发获取snmp信息及性能测试
Mar 27 Python
解决Django模板无法使用perms变量问题的方法
Sep 10 Python
在PyCharm中三步完成PyPy解释器的配置的方法
Oct 29 Python
对Python+opencv将图片生成视频的实例详解
Jan 08 Python
python快排算法详解
Mar 04 Python
详解python执行shell脚本创建用户及相关操作
Apr 11 Python
python占位符输入方式实例
May 27 Python
python3爬取torrent种子链接实例
Jan 16 Python
使用Python脚本从文件读取数据代码实例
Jan 19 Python
信号生成及DFT的python实现方式
Feb 25 Python
python GUI库图形界面开发之PyQt5计数器控件QSpinBox详细使用方法与实例
Feb 28 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
防止MySQL注入或HTML表单滥用的PHP程序
2009/01/21 PHP
Yii使用find findAll查找出指定字段的实现方法
2014/09/05 PHP
在Linux系统下一键重新安装WordPress的脚本示例
2015/06/30 PHP
ThinkPHP函数详解之M方法和R方法
2015/09/10 PHP
php文件上传及下载附带显示文件及目录功能
2017/04/27 PHP
PHP实现从上往下打印二叉树的方法
2018/01/18 PHP
PHP 计算两个特别大的整数实例代码
2018/05/07 PHP
js数组的操作详解
2013/03/27 Javascript
查看图片(前进后退)功能实现js代码
2013/04/24 Javascript
文字垂直滚动之javascript代码
2015/07/29 Javascript
jquery+css实现绚丽的横向二级下拉菜单-附源码下载
2015/08/23 Javascript
JS验证邮件地址格式方法小结
2015/12/01 Javascript
深入理解JavaScript中的预解析
2017/01/04 Javascript
深入理解Angular4中的依赖注入
2017/06/07 Javascript
validationEngine 表单验证插件使用实例代码
2017/06/15 Javascript
详解Vue双向数据绑定原理解析
2017/09/11 Javascript
webpack+vue2构建vue项目骨架的方法
2018/01/09 Javascript
ES6 对象的新功能与解构赋值介绍
2019/02/05 Javascript
react-intl实现React国际化多语言的方法
2020/09/27 Javascript
[03:24]DOTA2超级联赛专访hao 大翻盘就是逆袭
2013/05/24 DOTA
python和pyqt实现360的CLable控件
2014/02/21 Python
Python面向对象编程中的类和对象学习教程
2015/03/30 Python
Python编程实现二叉树及七种遍历方法详解
2017/06/02 Python
[原创]Python入门教程5. 字典基本操作【定义、运算、常用函数】
2018/11/01 Python
用Pycharm实现鼠标滚轮控制字体大小的方法
2019/01/15 Python
python使用time、datetime返回工作日列表实例代码
2019/05/09 Python
Python爬虫 批量爬取下载抖音视频代码实例
2019/08/16 Python
django序列化时使用外键的真实值操作
2020/07/15 Python
Matplotlib配色之Colormap详解
2021/01/05 Python
同学聚会欢迎辞
2014/01/14 职场文书
学校万圣节活动方案
2014/02/13 职场文书
学校端午节活动方案
2014/08/23 职场文书
正风肃纪剖析材料范文
2014/10/10 职场文书
Python+Appium实现自动抢微信红包
2021/05/21 Python
浅谈MySQL next-key lock 加锁范围
2021/06/07 MySQL
python绘制云雨图raincloud plot
2022/08/05 Python