解决Pytorch半精度浮点型网络训练的问题


Posted in Python onMay 24, 2021

用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:

1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

另外,SGD算法对于半精度和全精度计算均没有问题。

还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。

但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

补充:pytorch半精度,混合精度,单精度训练的区别amp.initialize

看代码吧~

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except:
    mixed_precision = False  # not installed

 model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)

为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。

文档地址是:https://nvidia.github.io/apex/index.html

该 工具 提供了三个功能,amp、parallel和normalization。由于目前该工具还是0.1版本,功能还是很基础的,在最后一个normalization功能中只提供了LayerNorm层的复现,实际上在后续的使用过程中会发现,出现问题最多的是pytorch的BN层。

第二个工具是pytorch的分布式训练的复现,在文档中描述的是和pytorch中的实现等价,在代码中可以选择任意一个使用,实际使用过程中发现,在使用混合精度训练时,使用Apex复现的parallel工具,能避免一些bug。

默认训练方式是 单精度float32

import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

半精度 model(img.half())

import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img.half())
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

接下来是混合精度的实现,这里主要用到Apex的amp工具。代码修改为:

加上这一句封装,

model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 # loss.backward()
 with amp.scale_loss(loss, optimizer) as scaled_loss:
     scaled_loss.backward()

 optimizer.step()
 optimizer.zero_grad()

实际流程为:调用amp.initialize按照预定的opt_level对model和optimizer进行设置。在计算loss时使用amp.scale_loss进行回传。

需要注意以下几点:

在调用amp.initialize之前,模型需要放在GPU上,也就是需要调用cuda()或者to()。

在调用amp.initialize之前,模型不能调用任何分布式设置函数。

此时输入数据不需要在转换为半精度。

在使用混合精度进行计算时,最关键的参数是opt_level。他一共含有四种设置值:‘00',‘01',‘02',‘03'。实际上整个amp.initialize的输入参数很多:

但是在实际使用过程中发现,设置opt_level即可,这也是文档中例子的使用方法,甚至在不同的opt_level设置条件下,其他的参数会变成无效。(已知BUG:使用‘01'时设置keep_batchnorm_fp32的值会报错)

概括起来:

00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。

03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。这也是Apex的一大卖点。

在Pytorch中,BN层分为train和eval两种操作。

实现时若为单精度网络,会调用CUDNN进行计算加速。常规训练过程中BN层会被设为train。Apex优化了这种情况,通过设置keep_batchnorm_fp32参数,能够保证此时BN层使用CUDNN进行计算,达到最好的计算速度。

但是在一些fine tunning场景下,BN层会被设为eval(我的模型就是这种情况)。此时keep_batchnorm_fp32的设置并不起作用,训练会产生数据类型不正确的bug。此时需要人为的将所有BN层设置为半精度,这样将不能使用CUDNN加速。

一个设置的参考代码如下:

def fix_bn(m):
 classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
     m.eval().half()

model.apply(fix_bn)

实际测试下来,最后的模型准确度上感觉差别不大,可能有轻微下降;时间上变化不大,这可能会因不同的模型有差别;显存开销上确实有很大的降低。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Centos5.x下升级python到python2.7版本教程
Feb 14 Python
numpy向空的二维数组中添加元素的方法
Nov 01 Python
Python读取指定日期邮件的实例
Feb 01 Python
使用Python Pandas处理亿级数据的方法
Jun 24 Python
django迁移数据库错误问题解决
Jul 29 Python
Django ORM多对多查询方法(自定义第三张表&ManyToManyField)
Aug 09 Python
使用Python调取任意数字资产钱包余额功能
Aug 15 Python
python打开使用的方法
Sep 30 Python
python跨文件使用全局变量的实现
Nov 17 Python
python-图片流传输的思路及示例(url转换二维码)
Dec 21 Python
Django如何重置migration的几种情景
Feb 24 Python
python中的装饰器该如何使用
Jun 18 Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
You might like
ThinkPHP调用common/common.php函数提示错误function undefined的解决方法
2014/08/25 PHP
Yii使用find findAll查找出指定字段的实现方法
2014/09/05 PHP
niceTitle 基于jquery的超链接提示插件
2010/05/31 Javascript
jQuery UI Datepicker length为空或不是对象错误的解决方法
2010/12/19 Javascript
继续学习javascript闭包
2015/12/03 Javascript
jQuery判断checkbox选中状态
2016/05/12 Javascript
javascript中BOM基础知识总结
2017/02/14 Javascript
JavaScript中三个等号和两个等号你了解多少
2017/07/04 Javascript
微信小程序引用公共js里的方法的实例详解
2017/08/17 Javascript
Node.JS 循环递归复制文件夹目录及其子文件夹下的所有文件
2017/09/18 Javascript
在微信小程序中使用图表的方法示例
2019/04/25 Javascript
vue项目中mock.js的使用及基本用法
2019/05/22 Javascript
[17:36]VG战队纪录片
2014/08/21 DOTA
[01:57]DOTA2上海特锦赛小组赛解说单车采访花絮
2016/02/27 DOTA
[01:30]2016国际邀请赛中国区预选赛神秘商店火爆开启
2016/06/26 DOTA
[00:32]2018DOTA2亚洲邀请赛出场——LGD
2018/04/04 DOTA
[10:05]DOTA2-DPC中国联赛 正赛 iG vs PSG.LGD 选手采访
2021/03/11 DOTA
Python函数中定义参数的四种方式
2014/11/30 Python
python简单获取数组元素个数的方法
2015/07/13 Python
深入理解 Python 中的多线程 新手必看
2016/11/20 Python
python链接oracle数据库以及数据库的增删改查实例
2018/01/30 Python
python利用跳板机ssh远程连接redis的方法
2019/02/19 Python
python3通过selenium爬虫获取到dj商品的实例代码
2019/04/25 Python
tensorflow2.0与tensorflow1.0的性能区别介绍
2020/02/07 Python
Python图像处理库PIL的ImageGrab模块介绍详解
2020/02/26 Python
python获取linux系统信息的三种方法
2020/10/14 Python
使用css3绘制出各种几何图形
2016/08/17 HTML / CSS
Kingsoft金山公司C/C++笔试题
2016/05/10 面试题
后备干部考察材料
2014/02/12 职场文书
正风肃纪剖析材料
2014/02/18 职场文书
人力资源部经理的岗位职责
2014/03/04 职场文书
《真想变成大大的荷叶》教学反思
2014/04/14 职场文书
会议欢迎词
2015/01/23 职场文书
大学迎新生的欢迎词
2019/06/25 职场文书
Python pyecharts绘制条形图详解
2022/04/02 Python
Win11更新失败并提示0xc1900101
2022/04/19 数码科技