PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的迭代器漫谈
Feb 03 Python
关于反爬虫的一些简单总结
Dec 13 Python
Python遍历某目录下的所有文件夹与文件路径
Mar 15 Python
Python 查找字符在字符串中的位置实例
May 02 Python
浅谈python3.6的tkinter运行问题
Feb 22 Python
记一次pyinstaller打包pygame项目为exe的过程(带图片)
Mar 02 Python
python 实现图像快速替换某种颜色
Jun 04 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 Python
Django视图、传参和forms验证操作
Jul 15 Python
Python实现文本文件拆分写入到多个文本文件的方法
Apr 18 Python
Python基于百度AI实现抓取表情包
Jun 27 Python
详解MindSpore自定义模型损失函数
Jun 30 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
php初始化对象和析构函数的简单实例
2014/03/11 PHP
Yii中CGridView关联表搜索排序方法实例详解
2014/12/03 PHP
php中Snoopy类用法实例
2015/06/19 PHP
php文件上传类完整实例
2016/05/14 PHP
Linux下源码包安装Swoole及基本使用操作图文详解
2019/04/02 PHP
用js计算页面执行时间的函数
2006/12/07 Javascript
Jquery之美中不足小结
2011/02/16 Javascript
JavaScript中函数声明优先于变量声明的实例分析
2012/03/01 Javascript
基于jQuery实现多层次的手风琴效果附源码
2015/09/21 Javascript
jquery zTree异步加载、模糊搜索简单实例分享
2016/03/24 Javascript
AngularJS实现自定义指令及指令配置项的方法
2017/11/20 Javascript
vue-cli项目中使用echarts图表实例
2018/10/22 Javascript
vue实现动态显示与隐藏底部导航的方法分析
2019/02/11 Javascript
jQuery实现滑动星星评分效果(每日分享)
2019/11/13 jQuery
Vue.js仿Select下拉框效果
2020/02/18 Javascript
python 中的列表解析和生成表达式
2011/03/10 Python
详解Python使用simplejson模块解析JSON的方法
2016/03/24 Python
浅谈Python 集合(set)类型的操作——并交差
2016/06/30 Python
python 图像处理画一个正弦函数代码实例
2019/09/10 Python
jupyter notebook运行命令显示[*](解决办法)
2020/05/18 Python
opencv 阈值分割的具体使用
2020/07/08 Python
Python实现敏感词过滤的4种方法
2020/09/12 Python
python查询MySQL将数据写入Excel
2020/10/29 Python
Python基于爬虫实现全网搜索并下载音乐
2021/02/14 Python
新奇的小玩意:IWOOT
2016/07/21 全球购物
Chain Reaction Cycles芬兰:世界上最大的在线自行车商店
2017/12/06 全球购物
HelloFresh奥地利:立即订购烹饪盒
2019/02/22 全球购物
西雅图的买手店:Totokaelo
2019/10/19 全球购物
人力资源管理专业学生自我评价
2013/11/20 职场文书
小学开学寄语
2014/01/19 职场文书
表彰大会策划方案
2014/05/13 职场文书
课外科技活动总结
2014/08/27 职场文书
2014年工商所工作总结
2014/12/09 职场文书
喋血孤城观后感
2015/06/08 职场文书
高中化学教学反思
2016/02/22 职场文书
JavaScript原型链中函数和对象的理解
2022/06/16 Javascript