PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础入门学习笔记(Python环境搭建)
Jan 13 Python
Python制作词云的方法
Jan 03 Python
python模仿网页版微信发送消息功能
Feb 24 Python
用tensorflow搭建CNN的方法
Mar 05 Python
基于pandas数据样本行列选取的方法
Apr 20 Python
python爬虫基础教程:requests库(二)代码实例
Apr 09 Python
解决安装python3.7.4报错Can''t connect to HTTPS URL because the SSL module is not available
Jul 31 Python
Python实现图片添加文字
Nov 26 Python
Python3批量创建Crowd用户并分配组
May 20 Python
Selenium之模拟登录铁路12306的示例代码
Jul 31 Python
python中sys模块是做什么用的
Aug 16 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
php利用header函数实现文件下载时直接提示保存
2009/11/12 PHP
php curl选项列表(超详细)
2013/07/01 PHP
PHP防止跨域提交表单
2013/11/01 PHP
Yii实现多数据库主从读写分离的方法
2014/12/29 PHP
php 静态属性和静态方法区别详解
2017/04/09 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
JavaScript实现动态增加文件域表单
2009/02/12 Javascript
jquery last-child 列表最后一项的样式
2010/01/22 Javascript
JQuery之拖拽插件实现代码
2011/04/14 Javascript
Jquery实现搜索框提示功能示例代码
2013/08/13 Javascript
HTML页面滚动时获取离页面顶部的距离2种实现方法
2013/09/05 Javascript
Javascript实现视频轮播在pc端与移动端均可
2013/09/29 Javascript
node.js中的fs.link方法使用说明
2014/12/15 Javascript
jQuery可见性过滤选择器用法示例
2016/09/09 Javascript
微信JS-SDK自定义分享功能实例详解【分享给朋友/分享到朋友圈】
2016/11/25 Javascript
在node中如何使用 ES6
2017/04/22 Javascript
vue 父组件调用子组件方法及事件
2018/03/29 Javascript
简化vuex的状态管理方案的方法
2018/06/02 Javascript
vue+element模态框中新增模态框和删除功能
2019/06/11 Javascript
easyUI使用分页过滤器对数据进行分页操作实例分析
2020/06/01 Javascript
举例讲解Python面向对象编程中类的继承
2016/06/17 Python
基于Python的关键字监控及告警
2017/07/06 Python
PyQt 图解Qt Designer工具的使用方法
2019/08/06 Python
Python3实现zip分卷压缩过程解析
2019/10/09 Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
2020/06/23 Python
python爬虫---requests库的用法详解
2020/09/28 Python
SQL里面如何插入自动增长序列号字段
2012/03/29 面试题
工商企业管理应届生求职信
2014/05/04 职场文书
2014年销售经理工作总结
2014/12/01 职场文书
优秀班主任事迹材料
2014/12/16 职场文书
实习班主任自我评价
2015/03/11 职场文书
上班迟到检讨书
2015/05/06 职场文书
用CSS3画一个爱心
2021/04/27 HTML / CSS
Java8中接口的新特性使用指南
2021/11/01 Java/Android
Spring Boot 底层原理基础深度解析
2022/04/03 Java/Android
使用 Koa + TS + ESLlint 搭建node服务器的过程详解
2022/05/30 NodeJs