PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python编写脚本获取手机当前应用apk的信息
Jul 21 Python
Python实现删除文件但保留指定文件
Jun 21 Python
python实现根据指定字符截取对应的行的内容方法
Oct 23 Python
详解Python安装tesserocr遇到的各种问题及解决办法
Mar 07 Python
Django实现发送邮件功能
Jul 18 Python
python requests库爬取豆瓣电视剧数据并保存到本地详解
Aug 10 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
tensorflow生成多个tfrecord文件实例
Feb 17 Python
python3注册全局热键的实现
Mar 22 Python
Django Session和Cookie分别实现记住用户登录状态操作
Jul 02 Python
python 输入字符串生成所有有效的IP地址(LeetCode 93号题)
Oct 15 Python
python抢购软件/插件/脚本附完整源码
Mar 04 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
PHP的FTP学习(二)[转自奥索]
2006/10/09 PHP
GD输出汉字的函数的分析
2006/10/09 PHP
PHP中遇到的时区问题解决方法
2015/07/23 PHP
PHP的Yii框架中移除组件所绑定的行为的方法
2016/03/18 PHP
PHP实现无限极分类的两种方式示例【递归和引用方式】
2019/03/25 PHP
php输出形式实例整理
2020/05/05 PHP
基于JQuery的模拟苹果桌面Dock效果(稳定版)
2012/10/15 Javascript
jQuery中RadioButtonList的功能及用法实例介绍
2013/08/23 Javascript
javascript实现鼠标拖动改变层大小的方法
2015/04/30 Javascript
JS实现弹性漂浮效果的广告代码
2015/09/02 Javascript
jQuery-1.9.1源码分析系列(十一)DOM操作续之克隆节点
2015/12/01 Javascript
JavaScript实现实时更新系统时间的实例代码
2017/04/04 Javascript
在vue.js中抽出公共代码的方法示例
2017/06/08 Javascript
JS获取鼠标坐标并且根据鼠标位置不同弹出不同内容
2017/06/12 Javascript
JS非行间样式获取函数的实例代码
2018/06/05 Javascript
微信小程序踩坑记录之解决tabBar.list[3].selectedIconPath大小超过40kb
2018/07/04 Javascript
详解小程序rich-text对富文本支持方案
2018/11/28 Javascript
js使用swiper实现层叠轮播效果实例代码
2018/12/12 Javascript
微信小程序 行的删除和增加操作实现详解
2019/09/29 Javascript
python实现360的字符显示界面
2014/02/21 Python
python使用itchat库实现微信机器人(好友聊天、群聊天)
2018/01/04 Python
深入了解Python中pop和remove的使用方法
2018/01/09 Python
python 动态调用函数实例解析
2019/10/21 Python
python 实现生成均匀分布的点
2019/12/05 Python
python 多线程爬取壁纸网站的示例
2021/02/20 Python
canvas之万花筒效果的简单实现(推荐)
2016/08/16 HTML / CSS
大学应届生求职简历的自我评价
2013/10/08 职场文书
学前教育毕业生自荐信
2013/10/29 职场文书
《找不到快乐的波斯猫》教学反思
2014/02/24 职场文书
优秀三好学生事迹材料
2014/08/31 职场文书
乡镇党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
2014群众路线学习笔记
2014/11/06 职场文书
2014年安全保卫工作总结
2014/11/13 职场文书
2014年小学少先队工作总结
2014/12/18 职场文书
一年级小学生评语大全
2014/12/25 职场文书
MySQL Innodb关键特性之插入缓冲(insert buffer)
2021/04/08 MySQL