PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
初步解析Python下的多进程编程
Apr 28 Python
Python列表删除的三种方法代码分享
Oct 31 Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 Python
在dataframe两列日期相减并且得到具体的月数实例
Jul 03 Python
python自动发微信监控报警
Sep 06 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
Oct 25 Python
python实现批量文件重命名
Oct 31 Python
Python实现变声器功能(萝莉音御姐音)
Dec 05 Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 Python
详解使用Python写一个向数据库填充数据的小工具(推荐)
Sep 11 Python
python实现简单的学生管理系统
Feb 22 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
php使浏览器直接下载pdf文件的方法
2013/11/15 PHP
浅谈laravel数据库查询返回的数据形式
2019/10/21 PHP
利用404错误页面实现UrlRewrite的实现代码
2008/08/20 Javascript
javascript 打印内容方法小结
2009/11/04 Javascript
用js实现判断当前网址的来路如果不是指定的来路就跳转到指定页面
2011/05/02 Javascript
拥抱模块化的JavaScript
2012/03/07 Javascript
JS 实现获取打开一个界面中输入的值
2013/03/19 Javascript
alert出数组中的随即值代码
2014/09/25 Javascript
最全的JavaScript开发工具列表 总有一款适合你
2017/06/29 Javascript
vue服务端渲染的实例代码
2017/08/28 Javascript
webpack4+express+mongodb+vue实现增删改查的示例
2018/11/08 Javascript
vue-cli3全面配置详解
2018/11/14 Javascript
一些你可能不熟悉的JS知识点总结
2019/03/15 Javascript
微信小程序时间戳转日期的详解
2019/04/30 Javascript
[46:23]OG vs EG 2018国际邀请赛淘汰赛BO3 第一场 8.23
2018/08/24 DOTA
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
解决tensorflow测试模型时NotFoundError错误的问题
2018/07/26 Python
Python datetime包函数简单介绍
2019/08/28 Python
用python求一重积分和二重积分的例子
2019/12/06 Python
解决torch.autograd.backward中的参数问题
2020/01/07 Python
在keras中实现查看其训练loss值
2020/06/16 Python
flask开启多线程的具体方法
2020/08/02 Python
美国瑜伽品牌:Gaiam
2017/10/31 全球购物
乌克兰在线药房:Аптека24
2019/10/30 全球购物
花卉与景观设计系大学生求职信
2013/10/01 职场文书
产品工艺师的岗位职责
2013/11/15 职场文书
家长对孩子评语
2014/01/30 职场文书
弘扬雷锋精神演讲稿
2014/05/10 职场文书
运动会口号16字
2014/06/07 职场文书
学校联谊协议书
2014/09/16 职场文书
党员个人剖析材料
2014/09/30 职场文书
汉字听写大会观后感
2015/06/12 职场文书
2016年公司新年寄语
2015/08/17 职场文书
导游词之金鞭溪风景区
2019/09/12 职场文书
面试官问我Mysql的存储引擎了解多少
2022/08/05 MySQL
CSS 鼠标选中文字后改变背景色的实现代码
2023/05/21 HTML / CSS