PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程之多态用法实例详解
May 19 Python
Python中字典和集合学习小结
Jul 07 Python
go和python变量赋值遇到的一个问题
Aug 31 Python
简单实现python数独游戏
Mar 30 Python
Python3.6笔记之将程序运行结果输出到文件的方法
Apr 22 Python
使用EduBlock轻松学习Python编程
Oct 08 Python
想学python 这5本书籍你必看!
Dec 11 Python
python模拟菜刀反弹shell绕过限制【推荐】
Jun 25 Python
pytorch-神经网络拟合曲线实例
Jan 15 Python
python range实例用法分享
Feb 06 Python
python中Ansible模块的Playbook的具体使用
May 28 Python
Python基于tkinter canvas实现图片裁剪功能
Nov 05 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
PHP-redis中文文档介绍
2013/02/07 PHP
php 数组随机取值的简单实例
2016/05/23 PHP
基于jQuery的简单的列表导航菜单
2011/03/02 Javascript
jQuery Form 页面表单提交的小例子
2013/11/15 Javascript
JavaScript定义类的几种方式总结
2014/01/06 Javascript
jquery中常用的函数和属性详细解析
2014/03/07 Javascript
jQuery检测某个元素是否存在代码分享
2015/07/09 Javascript
javascript作用域问题实例分析
2015/07/13 Javascript
AngularJS实现全选反选功能
2015/12/08 Javascript
浅谈jquery点击label触发2次的问题
2016/06/12 Javascript
js 基础篇必看(点击事件轮播图的简单实现)
2016/08/20 Javascript
Bootstrap 下拉多选框插件Bootstrap Multiselect
2017/01/22 Javascript
js实现增加数字显示的环形进度条效果
2017/02/05 Javascript
jquery.flot.js简单绘制折线图用法示例
2017/03/13 Javascript
基于Bootstrap table组件实现多层表头的实例代码
2017/09/07 Javascript
30分钟快速入门掌握ES6/ES2015的核心内容(下)
2018/04/18 Javascript
JS限制输入框输入的实现代码
2018/07/02 Javascript
Vue源码解析之数组变异的实现
2018/12/04 Javascript
解决Vue 刷新页面导航显示高亮位置不对问题
2019/12/25 Javascript
python登录豆瓣并发帖的方法
2015/07/08 Python
windows10系统中安装python3.x+scrapy教程
2016/11/08 Python
python3+requests接口自动化session操作方法
2018/10/13 Python
python实现朴素贝叶斯算法
2018/11/19 Python
python版百度语音识别功能
2019/07/09 Python
python3 字符串知识点学习笔记
2020/02/08 Python
使用Python文件读写,自定义分隔符(custom delimiter)
2020/07/05 Python
字中字效果的实现【html5实例】
2016/05/03 HTML / CSS
美国美发品牌:Bumble and Bumble
2016/10/08 全球购物
上班玩手机检讨书
2014/02/17 职场文书
求职信内容怎么写
2014/05/26 职场文书
纪念9.18事变演讲稿
2014/09/14 职场文书
专项资金申请报告
2015/05/15 职场文书
重温入党誓词主持词
2015/06/29 职场文书
诚信高考倡议书
2019/06/24 职场文书
Pandas加速代码之避免使用for循环
2021/05/30 Python
mybatis 解决从列名到属性名的自动映射失败问题
2021/06/30 Java/Android