pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
Python设计模式之代理模式实例
Apr 26 Python
python通过pip更新所有已安装的包实现方法
May 19 Python
单链表反转python实现代码示例
Feb 08 Python
python在回调函数中获取返回值的方法
Feb 22 Python
深入了解Django中间件及其方法
Jul 26 Python
导入tensorflow时报错:cannot import name 'abs'的解决
Oct 10 Python
python实现PCA降维的示例详解
Feb 24 Python
Pytorch转onnx、torchscript方式
May 25 Python
win10下python3.8的PIL库安装过程
Jun 08 Python
基于keras中的回调函数用法说明
Jun 17 Python
Python下使用Trackbar实现绘图板
Oct 27 Python
python playwright之元素定位示例详解
Jul 23 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
PHP获取163、gmail、126等邮箱联系人地址【已测试2009.10.10】
2009/10/11 PHP
php 各种应用乱码问题的解决方法
2010/05/09 PHP
javascript 火狐(firefox)不显示本地图片问题解决
2008/07/05 Javascript
jquery 图片预加载 自动等比例缩放插件
2008/12/25 Javascript
javascript EXCEL 操作类代码
2009/07/30 Javascript
Lazy Load 延迟加载图片的 jQuery 插件
2010/02/06 Javascript
基于jquery tab切换(防止页面刷新)
2012/05/23 Javascript
jquery 简单应用示例总结
2013/08/09 Javascript
简单选项卡 js和jquery制作方法分享
2014/02/26 Javascript
jQuery制作简单柱状图实例
2015/01/28 Javascript
jquery衣服颜色选取插件效果代码分享
2015/08/28 Javascript
javascript字符串函数汇总
2015/12/06 Javascript
jQuery侧边栏实现代码
2016/05/06 Javascript
微信小程序 loading(加载中提示框)实例
2016/10/28 Javascript
JavaScript实现图片懒加载(Lazyload)
2016/11/28 Javascript
Vue scrollBehavior 滚动行为实现后退页面显示在上次浏览的位置
2019/05/27 Javascript
如何通过JS实现转码与解码
2020/02/21 Javascript
webpack安装配置与常见使用过程详解(结合vue)
2020/06/01 Javascript
详细分析Node.js 多进程
2020/06/22 Javascript
详解Python多线程
2016/11/14 Python
Python列表和元组的定义与使用操作示例
2017/07/26 Python
python输出100以内的质数与合数实例代码
2018/07/08 Python
python 输出列表元素实例(以空格/逗号为分隔符)
2019/12/25 Python
Python TCPServer 多线程多客户端通信的实现
2019/12/31 Python
Python实现括号匹配方法详解
2020/02/10 Python
python将logging模块封装成单独模块并实现动态切换Level方式
2020/05/12 Python
numpy的Fancy Indexing和array比较详解
2020/06/11 Python
意大利值得信赖的在线超级药房:PillolaStore
2020/02/05 全球购物
车间班长岗位职责
2013/11/30 职场文书
领导班子党的群众路线对照检查材料
2014/09/25 职场文书
2015年村党支部工作总结
2015/04/30 职场文书
学校2015年纠风工作总结
2015/05/15 职场文书
2015年幼儿园中班下学期工作总结
2015/05/22 职场文书
党员观看《筑梦中国》心得体会
2016/01/18 职场文书
Redis高级数据类型Hyperloglog、Bitmap的使用
2021/05/24 Redis
Windows Server 2012 R2服务器安装与配置的完整步骤
2022/07/15 Servers