PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现汉诺塔递归算法经典案例
Mar 01 Python
python 读写、创建 文件的方法(必看)
Sep 12 Python
windows下python安装paramiko模块和pycrypto模块(简单三步)
Jul 06 Python
解决python3中解压zip文件是文件名乱码的问题
Mar 22 Python
详解pyqt5 动画在QThread线程中无法运行问题
May 05 Python
python 从csv读数据到mysql的实例
Jun 21 Python
简单了解Django应用app及分布式路由
Jul 24 Python
Python实现发票自动校核微信机器人的方法
May 22 Python
Python新手如何进行闭包时绑定变量操作
May 29 Python
利用python+ffmpeg合并B站视频及格式转换的实例代码
Nov 24 Python
简单介绍Python的第三方库yaml
Jun 18 Python
pandas时间序列之pd.to_datetime()的实现
Jun 16 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
PHP新手上路(七)
2006/10/09 PHP
PHP基础教程(php入门基础教程)一些code代码
2013/01/06 PHP
php实现监控varnish缓存服务器的状态
2014/12/30 PHP
PHP利用APC模块实现大文件上传进度条的方法
2015/10/29 PHP
PHP判断手机是IOS还是Android
2015/12/09 PHP
Zend Framework前端控制器用法示例
2016/12/11 PHP
PHP中字符串长度的截取用法示例
2017/01/12 PHP
兼容IE和Firefox的javascript获取iframe文档内容的函数
2011/08/15 Javascript
文字不间断滚动(上下左右)实例代码
2013/04/21 Javascript
Jquery使用Firefox FireBug插件调试Ajax步骤讲解
2013/12/02 Javascript
让jQuery与其他JavaScript库并存避免冲突的方法
2013/12/23 Javascript
使用jquery 简单实现下拉菜单
2015/01/14 Javascript
jquery中添加属性和删除属性
2015/06/03 Javascript
jQuery中的一些小技巧
2017/01/18 Javascript
利用JS实现文字的聚合动画效果
2017/01/22 Javascript
防止重复发送 Ajax 请求
2017/02/15 Javascript
详解vue项目构建与实战
2017/06/27 Javascript
JavaScript之json_动力节点Java学院整理
2017/06/29 Javascript
通过一个简单的例子学会vuex与模块化
2017/11/22 Javascript
Vue+axios+WebApi+NPOI导出Excel文件实例方法
2019/06/05 Javascript
小程序Scroll-view上拉滚动刷新数据
2020/06/21 Javascript
vue 子组件修改data或调用操作
2020/08/07 Javascript
[01:57]2016完美“圣”典风云人物:国士无双专访
2016/12/04 DOTA
Python3.6简单操作Mysql数据库
2017/09/12 Python
pyqt5自定义信号实例解析
2018/01/31 Python
python爬虫中get和post方法介绍以及cookie作用
2018/02/08 Python
Python科学计算包numpy用法实例详解
2018/02/08 Python
Python批处理更改文件名os.rename的方法
2018/10/26 Python
德国亚洲食品网上商店:asiafoodland.de
2019/12/28 全球购物
"引用"与指针的区别是什么
2016/09/07 面试题
口腔工艺技术专业毕业生自荐信
2013/09/27 职场文书
聚美优品励志广告词
2014/03/14 职场文书
质量承诺书怎么写
2014/05/24 职场文书
中层干部培训方案
2014/06/16 职场文书
室内趣味活动方案
2014/08/24 职场文书
谢师宴学生致辞
2015/07/27 职场文书