PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量导出导入MySQL用户的方法
Nov 15 Python
Python 将RGB图像转换为Pytho灰度图像的实例
Nov 14 Python
Python返回数组/List长度的实例
Jun 23 Python
python dict 相同key 合并value的实例
Jan 21 Python
简单了解python的内存管理机制
Jul 08 Python
Python中的类与类型示例详解
Jul 10 Python
Django外键(ForeignKey)操作以及related_name的作用详解
Jul 29 Python
python编写俄罗斯方块
Mar 13 Python
基于Django signals 信号作用及用法详解
Mar 28 Python
python实现简单的tcp 文件下载
Sep 16 Python
Selenium结合BeautifulSoup4编写简单的python爬虫
Nov 06 Python
Python基于unittest实现测试用例执行
Nov 25 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
创建数据库php代码 用PHP写出自己的BLOG系统
2010/04/12 PHP
php数组函数序列之array_unique() - 去除数组中重复的元素值
2011/10/29 PHP
解析web文件操作常见安全漏洞(目录、文件名检测漏洞)
2013/06/29 PHP
重新认识php array_merge函数
2014/08/31 PHP
visual studio code 调试php方法(图文详解)
2017/09/15 PHP
Javascript 面向对象(三)接口代码
2012/05/23 Javascript
JS定时刷新页面及跳转页面的方法
2013/07/04 Javascript
Firefox和IE兼容性问题及解决方法总结
2013/10/08 Javascript
JavaScript子类用Object.getPrototypeOf去调用父类方法解析
2013/12/05 Javascript
JS删除字符串中重复字符方法
2014/03/09 Javascript
jquery uploadify 在FF下无效的解决办法
2014/09/26 Javascript
用简洁的jQuery方法toggleClass实现隔行换色
2014/10/22 Javascript
node.js中的fs.futimesSync方法使用说明
2014/12/17 Javascript
用window.onerror捕获并上报Js错误的方法
2016/01/27 Javascript
javascript Promise简单学习使用方法小结
2016/05/17 Javascript
BootStrap Validator使用注意事项(必看篇)
2016/09/28 Javascript
Vue.js 2.5新特性介绍(推荐)
2017/10/24 Javascript
Angular5.1新功能分享
2017/12/21 Javascript
vue如何截取字符串
2019/05/06 Javascript
vue实现搜索功能
2019/05/28 Javascript
vue控制多行文字展开收起的实现示例
2019/10/11 Javascript
OpenLayers3实现测量功能
2020/09/25 Javascript
python处理中文编码和判断编码示例
2014/02/26 Python
学习python之编写简单简单连接数据库并执行查询操作
2016/02/27 Python
教你用Type Hint提高Python程序开发效率
2016/08/08 Python
Python cookbook(数据结构与算法)将序列分解为单独变量的方法
2018/02/13 Python
python3 读写文件换行符的方法
2018/04/09 Python
python 将md5转为16字节的方法
2018/05/29 Python
Python中的集合介绍
2019/01/28 Python
调试Django时打印SQL语句的日志代码实例
2019/09/12 Python
简单了解pytest测试框架setup和tearDown
2020/04/14 Python
Python内置函数及功能简介汇总
2020/10/13 Python
德国高尔夫商店:Golfshop.de
2019/06/22 全球购物
什么是数据抽象
2016/11/26 面试题
介绍一下Linux中的链接
2016/06/05 面试题
家庭贫困证明
2015/06/16 职场文书