解决Pytorch半精度浮点型网络训练的问题


Posted in Python onMay 24, 2021

用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:

1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

另外,SGD算法对于半精度和全精度计算均没有问题。

还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。

但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

补充:pytorch半精度,混合精度,单精度训练的区别amp.initialize

看代码吧~

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except:
    mixed_precision = False  # not installed

 model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)

为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。

文档地址是:https://nvidia.github.io/apex/index.html

该 工具 提供了三个功能,amp、parallel和normalization。由于目前该工具还是0.1版本,功能还是很基础的,在最后一个normalization功能中只提供了LayerNorm层的复现,实际上在后续的使用过程中会发现,出现问题最多的是pytorch的BN层。

第二个工具是pytorch的分布式训练的复现,在文档中描述的是和pytorch中的实现等价,在代码中可以选择任意一个使用,实际使用过程中发现,在使用混合精度训练时,使用Apex复现的parallel工具,能避免一些bug。

默认训练方式是 单精度float32

import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

半精度 model(img.half())

import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img.half())
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

接下来是混合精度的实现,这里主要用到Apex的amp工具。代码修改为:

加上这一句封装,

model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 # loss.backward()
 with amp.scale_loss(loss, optimizer) as scaled_loss:
     scaled_loss.backward()

 optimizer.step()
 optimizer.zero_grad()

实际流程为:调用amp.initialize按照预定的opt_level对model和optimizer进行设置。在计算loss时使用amp.scale_loss进行回传。

需要注意以下几点:

在调用amp.initialize之前,模型需要放在GPU上,也就是需要调用cuda()或者to()。

在调用amp.initialize之前,模型不能调用任何分布式设置函数。

此时输入数据不需要在转换为半精度。

在使用混合精度进行计算时,最关键的参数是opt_level。他一共含有四种设置值:‘00',‘01',‘02',‘03'。实际上整个amp.initialize的输入参数很多:

但是在实际使用过程中发现,设置opt_level即可,这也是文档中例子的使用方法,甚至在不同的opt_level设置条件下,其他的参数会变成无效。(已知BUG:使用‘01'时设置keep_batchnorm_fp32的值会报错)

概括起来:

00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。

03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。这也是Apex的一大卖点。

在Pytorch中,BN层分为train和eval两种操作。

实现时若为单精度网络,会调用CUDNN进行计算加速。常规训练过程中BN层会被设为train。Apex优化了这种情况,通过设置keep_batchnorm_fp32参数,能够保证此时BN层使用CUDNN进行计算,达到最好的计算速度。

但是在一些fine tunning场景下,BN层会被设为eval(我的模型就是这种情况)。此时keep_batchnorm_fp32的设置并不起作用,训练会产生数据类型不正确的bug。此时需要人为的将所有BN层设置为半精度,这样将不能使用CUDNN加速。

一个设置的参考代码如下:

def fix_bn(m):
 classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
     m.eval().half()

model.apply(fix_bn)

实际测试下来,最后的模型准确度上感觉差别不大,可能有轻微下降;时间上变化不大,这可能会因不同的模型有差别;显存开销上确实有很大的降低。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模拟登陆阿里妈妈生成商品推广链接
Apr 03 Python
python批量替换页眉页脚实例代码
Jan 22 Python
TensorFlow损失函数专题详解
Apr 26 Python
基于python的多进程共享变量正确打开方式
Apr 28 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
Python3.5实现的三级菜单功能示例
Mar 25 Python
python反编译学习之字节码详解
May 19 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 Python
python tqdm 实现滚动条不上下滚动代码(保持一行内滚动)
Feb 19 Python
python3中datetime库,time库以及pandas中的时间函数区别与详解
Apr 16 Python
python rsa-oaep加密的示例代码
Sep 23 Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
You might like
PHP通过session id 实现session共享和登录验证的代码
2012/06/03 PHP
php读取文件内容的三种可行方法示例介绍
2014/02/08 PHP
Zend Framework过滤器Zend_Filter用法详解
2016/12/09 PHP
laravel-admin select框默认选中的方法
2019/10/03 PHP
二行代码解决全部网页木马
2008/03/28 Javascript
8个实用的jQuery技巧
2014/03/04 Javascript
jquery删除数据记录时的弹出提示效果
2014/05/06 Javascript
jQuery实用技巧必备(下)
2015/11/03 Javascript
JS如何设置元素样式的方法示例
2017/08/28 Javascript
layer更改皮肤的实现方法
2019/09/11 Javascript
Vue 请求传公共参数的操作
2020/07/31 Javascript
JS异步宏队列微队列原理详解
2020/09/09 Javascript
vue自定义插件封装,实现简易的elementUi的Message和MessageBox的示例
2020/11/20 Vue.js
[37:29]完美世界DOTA2联赛PWL S2 LBZS vs Forest 第二场 11.19
2020/11/19 DOTA
Python装饰器使用示例及实际应用例子
2015/03/06 Python
Python使用Srapy框架爬虫模拟登陆并抓取知乎内容
2016/07/02 Python
Python实现螺旋矩阵的填充算法示例
2017/12/28 Python
Python爬虫之pandas基本安装与使用方法示例
2018/08/08 Python
python的中异常处理机制
2018/08/30 Python
python如何获取列表中每个元素的下标位置
2019/07/01 Python
python 日期排序的实例代码
2019/07/11 Python
Python3 requests文件下载 期间显示文件信息和下载进度代码实例
2019/08/16 Python
python多线程高级锁condition简单用法示例
2019/11/07 Python
python中如何使用insert函数
2020/01/09 Python
40个你可能不知道的Python技巧附代码
2020/01/29 Python
孕妇内衣和胸罩:Cake Maternity
2018/07/16 全球购物
网站开发实习生的自我评价
2013/12/11 职场文书
请假条的格式
2014/04/11 职场文书
行政专员岗位职责说明书
2014/07/30 职场文书
2014年小学语文工作总结
2014/12/20 职场文书
2015年护士工作总结范文
2015/03/31 职场文书
重阳节简报
2015/07/20 职场文书
教务处教学工作总结
2015/08/10 职场文书
预防职务犯罪警示教育心得体会
2016/01/15 职场文书
win10+anaconda安装yolov5的方法及问题解决方案
2021/04/29 Python
python自动化操作之动态验证码、滑动验证码的降噪和识别
2021/08/30 Python