PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 获取本机ip地址的两个方法
Feb 25 Python
Python入门篇之正则表达式
Oct 20 Python
Python入门篇之面向对象
Oct 20 Python
Python实现方便使用的级联进度信息实例
May 05 Python
在Python中marshal对象序列化的相关知识
Jul 01 Python
Linux系统上Nginx+Python的web.py与Django框架环境
Dec 25 Python
Python读写Json涉及到中文的处理方法
Sep 12 Python
使用11行Python代码盗取了室友的U盘内容
Oct 23 Python
Pycharm如何运行.py文件的方法步骤
Mar 03 Python
next在python中返回迭代器的实例方法
Dec 15 Python
删除pycharm鼠标右键快捷键打开项目的操作
Jan 16 Python
聊聊Python中关于a=[[]]*3的反思
Jun 02 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
PHP4中session登录页面的应用
2008/07/25 PHP
php引用返回与取消引用的详解
2013/06/08 PHP
php批量删除数据库下指定前缀的表以prefix_为例
2014/08/24 PHP
自编函数解决pathinfo()函数处理中文问题
2014/11/03 PHP
一个经典的PHP文件上传类分享
2014/11/18 PHP
PHP中绘制图像的一些函数总结
2014/11/19 PHP
ThinkPHP3.2.3数据库设置新特性
2015/03/05 PHP
PHPMailer使用QQ邮箱实现邮件发送功能
2017/08/18 PHP
基于laravel where的高级使用方法
2019/10/10 PHP
ExtJS下grid的一些属性说明
2009/12/13 Javascript
jQuery powerFloat万能浮动层下拉层插件使用介绍
2010/12/27 Javascript
文本框根据输入内容自适应高度的代码
2011/10/24 Javascript
z-blog SyntaxHighlighter 长代码无法换行解决办法(jquery)
2014/11/16 Javascript
JavaScript将XML转成JSON的方法
2015/03/12 Javascript
jQuery制作效果超棒的手风琴折叠菜单
2015/04/03 Javascript
JS+CSS实现下拉列表框美化效果(3款)
2015/08/15 Javascript
AngularJS中处理多个promise的方式
2016/02/02 Javascript
artDialog+plupload实现多文件上传
2016/07/19 Javascript
Javascript中字符串replace方法的第二个参数探究
2016/12/05 Javascript
js中apply与call简单用法详解
2017/11/06 Javascript
谈谈vue中mixin的一点理解
2017/12/12 Javascript
使用async、enterproxy控制并发数量的方法详解
2018/01/02 Javascript
vue控制多行文字展开收起的实现示例
2019/10/11 Javascript
基于redis的小程序登录实现方法流程分析
2020/05/25 Javascript
解决vue项目获取dom元素宽高总是不准确问题
2020/07/29 Javascript
python交互式图形编程实例(二)
2017/11/17 Python
python求最大连续子数组的和
2018/07/07 Python
python手写均值滤波
2020/02/19 Python
python批量修改交换机密码的示例
2020/09/22 Python
python 调用Google翻译接口的方法
2020/12/09 Python
纯CSS3实现绘制各种图形实现代码详细整理
2012/12/26 HTML / CSS
简短大学毕业感言
2014/01/18 职场文书
体育比赛口号
2014/06/09 职场文书
一个都不能少观后感
2015/06/04 职场文书
小学记事作文之200字
2019/08/06 职场文书
Python基础之赋值,浅拷贝,深拷贝的区别
2021/04/30 Python