pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
一个计算身份证号码校验位的Python小程序
Aug 15 Python
python使用socket连接远程服务器的方法
Apr 29 Python
Request的中断和ErrorHandler实例解析
Feb 12 Python
对python使用http、https代理的实例讲解
May 07 Python
matplotlib 纵坐标轴显示数据值的实例
May 25 Python
python通过TimedRotatingFileHandler按时间切割日志
Jul 17 Python
Python GUI学习之登录系统界面篇
Aug 21 Python
python3.5 cv2 获取视频特定帧生成jpg图片
Aug 28 Python
Python Numpy 控制台完全输出ndarray的实现
Feb 19 Python
Python通过Pillow实现图片对比
Apr 29 Python
python TCP包注入方式
May 05 Python
Python模拟登入的N种方式(建议收藏)
May 31 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
php 采集书并合成txt格式的实现代码
2009/03/01 PHP
php通过字符串调用函数示例
2014/03/02 PHP
php生成年月日下载列表的方法
2015/04/24 PHP
PHP curl伪造IP地址和header信息代码实例
2015/04/27 PHP
php使用curl_init()和curl_multi_init()多线程的速度比较详解
2018/08/15 PHP
一个很简单的办法实现TD的加亮效果.
2006/06/29 Javascript
关于jQuery的inArray 方法介绍
2011/10/08 Javascript
jquery实现点击文字可编辑并修改保存至数据库
2014/04/15 Javascript
jQuery前端分页示例分享
2015/02/10 Javascript
jquery实现模拟百分比进度条渐变效果代码
2015/10/29 Javascript
JS中setTimeout的巧妙用法前端函数节流
2016/03/24 Javascript
AngularJS控制器继承自另一控制器
2016/05/09 Javascript
jquery使用on绑定a标签无效 只能用live解决
2016/06/02 Javascript
AngularJS 指令详细介绍
2016/07/27 Javascript
jQuery 的 ready()的纯js替代方法
2016/11/20 Javascript
js实现随机数字字母验证码
2017/06/19 Javascript
linux 后台运行node服务指令方法
2018/05/23 Javascript
vue二级菜单导航点击选中事件的方法
2018/09/12 Javascript
微信小程序使用蓝牙小插件
2019/09/23 Javascript
vue element-ui读取pdf文件的方法
2019/11/26 Javascript
vue实现给div绑定keyup的enter事件
2020/07/31 Javascript
[27:39]Ti4 循环赛第二日 LGD vs Fnatic
2014/07/11 DOTA
Python二维码生成库qrcode安装和使用示例
2014/12/16 Python
Python生成随机验证码的两种方法
2015/12/22 Python
Python安装官方whl包和tar.gz包的方法(推荐)
2017/06/04 Python
Python判断文件和字符串编码类型的实例
2017/12/21 Python
python 基本数据类型占用内存空间大小的实例
2018/06/12 Python
python实现简易淘宝购物
2019/11/22 Python
python 检测图片是否有马赛克
2020/12/01 Python
Python关于拓扑排序知识点讲解
2021/01/04 Python
苏格兰销售女装、男装和童装的连锁店:M&Co
2018/03/16 全球购物
公务员职务工作的自我评价
2013/11/01 职场文书
理工大学毕业生自荐信
2013/11/01 职场文书
老总助理工作岗位职责
2014/02/06 职场文书
致200米运动员广播稿
2014/02/06 职场文书
2021好看的国漫排行榜前十名 《完美世界》上榜,《元龙》排名第一
2022/03/18 国漫