PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
php使用递归与迭代实现快速排序示例
Jan 23 Python
Python使用os模块和fileinput模块来操作文件目录
Jan 19 Python
深入理解NumPy简明教程---数组3(组合)
Dec 17 Python
基于ID3决策树算法的实现(Python版)
May 31 Python
python统计字母、空格、数字等字符个数的实例
Jun 29 Python
Python二进制串转换为通用字符串的方法
Jul 23 Python
pytorch 调整某一维度数据顺序的方法
Dec 08 Python
Python简单获取二维数组行列数的方法示例
Dec 21 Python
解决pycharm 远程调试 上传 helpers 卡住的问题
Jun 27 Python
解决Python3 抓取微信账单信息问题
Jul 19 Python
工程师必须了解的LRU缓存淘汰算法以及python实现过程
Oct 15 Python
python中pow函数用法及功能说明
Dec 04 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
php下使用以下代码连接并测试
2008/04/09 PHP
php把数据表导出为Excel表的最简单、最快的方法(不用插件)
2014/05/10 PHP
利用PHP fsockopen 模拟POST/GET传送数据的方法
2015/09/22 PHP
ThinkPHP框架实现的邮箱激活功能示例
2018/06/15 PHP
jquery中实现简单的tabs插件功能的代码
2011/03/02 Javascript
精心挑选的15个jQuery下拉菜单制作教程
2012/06/15 Javascript
js图片闪动特效可以控制间隔时间如几分钟闪动一下
2014/08/12 Javascript
关闭页面时window.location事件未执行的原因分析及解决方案
2014/09/01 Javascript
JS/jQuery判断DOM节点是否存在的简单方法
2016/11/24 Javascript
使用vue框架 Ajax获取数据列表并用BootStrap显示出来
2017/04/24 Javascript
mac上node.js环境的安装测试
2017/07/03 Javascript
详解用vue编写弹出框组件
2017/07/04 Javascript
详解express与koa中间件模式对比
2017/08/07 Javascript
JS操作时间 - UNIX时间戳的简单介绍(必看篇)
2017/08/16 Javascript
vue-resouce设置请求头的三种方法
2017/09/12 Javascript
薪资那么高的Web前端必看书单
2017/10/13 Javascript
原生JS实现获取及修改CSS样式的方法
2018/09/04 Javascript
封装微信小程序http拦截器过程解析
2019/08/13 Javascript
利用webpack理解CommonJS和ES Modules的差异区别
2020/06/16 Javascript
利用H5api实现时钟的绘制(javascript)
2020/09/13 Javascript
[06:07]DOTA2-DPC中国联赛 正赛 Ehome vs VG 选手采访
2021/03/11 DOTA
python脚本设置系统时间的两种方法
2016/02/21 Python
matplotlib实现区域颜色填充
2019/03/18 Python
Python如何调用JS文件中的函数
2019/08/16 Python
python的列表List求均值和中位数实例
2020/03/03 Python
使用Pyhton 分析酒店针孔摄像头
2020/03/04 Python
python实现udp传输图片功能
2020/03/20 Python
苹果音乐订阅:Apple Music
2018/08/02 全球购物
extern在函数声明中是什么意思
2014/01/19 面试题
Ruby如何进行文件操作
2014/07/17 面试题
中学生打架检讨书
2014/02/10 职场文书
党支部三会一课计划
2014/09/24 职场文书
行政人事专员岗位职责
2015/04/07 职场文书
2015年机关党建工作总结
2015/05/22 职场文书
2016年元旦主持词
2015/07/06 职场文书
英文诗歌翻译方法(赏析)
2019/08/16 职场文书