pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
Python中列表(list)操作方法汇总
Aug 18 Python
python实现从web抓取文档的方法
Sep 26 Python
python字符串对其居中显示的方法
Jul 11 Python
Python爬虫之xlml解析库(全面了解)
Aug 08 Python
用python实现k近邻算法的示例代码
Sep 06 Python
python的pytest框架之命令行参数详解(上)
Jun 27 Python
pytorch多GPU并行运算的实现
Sep 27 Python
django框架中ajax的使用及避开CSRF 验证的方式详解
Dec 11 Python
手把手教你如何用Pycharm2020.1.1配置远程连接的详细步骤
Aug 07 Python
Python如何使用ConfigParser读取配置文件
Nov 12 Python
python爬虫 requests-html的使用
Nov 30 Python
python实现批量移动文件
Apr 05 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
层叠菜单的动态生成
2006/10/09 PHP
php完全过滤HTML,JS,CSS等标签
2009/01/16 PHP
PHP实现QQ登录实例代码
2016/01/14 PHP
Javascript----文件操作
2007/01/18 Javascript
使用隐藏的new来创建对象
2011/03/29 Javascript
jQuery对html元素取值与赋值的方法
2013/11/20 Javascript
原生js代码实现图片放大境效果
2016/10/30 Javascript
AngularJs上传前预览图片的实例代码
2017/01/20 Javascript
在小程序中使用Echart图表的示例代码
2018/08/02 Javascript
JavaScript 作用域scope简单汇总
2019/10/23 Javascript
[02:25]DOTA2英雄基础教程 熊战士
2014/01/03 DOTA
[42:32]DOTA2上海特级锦标赛B组资格赛#2 Fnatic VS Spirit第二局
2016/02/27 DOTA
python 简易计算器程序,代码就几行
2009/08/29 Python
python实现根据用户输入从电影网站获取影片信息的方法
2015/04/07 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
2017/07/20 Python
Python数据结构与算法之图结构(Graph)实例分析
2017/09/05 Python
基于Django filter中用contains和icontains的区别(详解)
2017/12/12 Python
matplotlib绘制动画代码示例
2018/01/02 Python
Python实现的朴素贝叶斯分类器示例
2018/01/06 Python
Python Json模块中dumps、loads、dump、load函数介绍
2018/05/15 Python
pyqt5 实现多窗口跳转的方法
2019/06/19 Python
python读取Excel表格文件的方法
2019/09/02 Python
python快速排序的实现及运行时间比较
2019/11/22 Python
日本7net购物网:书籍、漫画、杂志、DVD、游戏邮购
2017/02/17 全球购物
数据管理员的自我评价分享
2013/11/15 职场文书
求职信写作要突出重点
2014/01/01 职场文书
大学生职业生涯规划书参考模板
2014/03/05 职场文书
丧事主持词大全
2014/04/02 职场文书
公司募捐倡议书
2014/05/14 职场文书
理财学专业自荐书
2014/06/28 职场文书
自我查摆剖析材料
2014/10/11 职场文书
入党积极分子个人总结
2015/03/02 职场文书
学习焦裕禄先进事迹心得体会
2016/01/23 职场文书
Python 数据科学 Matplotlib图库详解
2021/07/07 Python
python装饰器代码解析
2022/03/23 Python
海康机器人重磅发布全新算法开发平台VM4.2
2022/04/21 数码科技