pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
python模拟登陆Tom邮箱示例分享
Jan 13 Python
Python使用BeautifulSoup库解析HTML基本使用教程
Mar 31 Python
Python多进程multiprocessing用法实例分析
Aug 18 Python
Python2实现的LED大数字显示效果示例
Sep 04 Python
Request的中断和ErrorHandler实例解析
Feb 12 Python
Python中常见的异常总结
Feb 20 Python
python解决js文件utf-8编码乱码问题(推荐)
May 02 Python
python实现12306登录并保存cookie的方法示例
Dec 17 Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 Python
python和go语言的区别是什么
Jul 20 Python
Numpy数组的广播机制的实现
Nov 03 Python
Python还能这么玩之用Python修改了班花的开机密码
Jun 04 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
2021年最新CPU天梯图
2021/03/04 数码科技
10款实用的PHP开源工具
2015/10/23 PHP
php中的explode()函数实例介绍
2019/01/18 PHP
JavaScript语句可以不以;结尾的烦恼
2007/03/08 Javascript
用JavaScript将从数据库中读取出来的日期型格式化为想要的类型。
2009/08/15 Javascript
ie中js创建checkbox默认选中问题探讨
2013/10/21 Javascript
Javascript实现返回上一页面并刷新的小例子
2013/12/11 Javascript
js验证输入是否为手机号码或电话号码示例
2013/12/30 Javascript
Javascript异步编程模型Promise模式详细介绍
2014/05/08 Javascript
JavaScript中遍历对象的property的3种方法介绍
2014/12/30 Javascript
JS解析XML文件和XML字符串详解
2015/04/17 Javascript
js检测iframe是否加载完成的方法
2015/11/26 Javascript
js实现简单的验证码
2015/12/25 Javascript
jQuery简介_动力节点Java学院整理
2017/07/04 jQuery
JavaScript中关于class的调用方法
2017/11/28 Javascript
如何用Node写页面爬虫的工具集
2018/10/26 Javascript
JS实现深度优先搜索求解两点间最短路径
2019/01/17 Javascript
如何获取vue单文件自身源码路径
2019/05/06 Javascript
Python遍历zip文件输出名称时出现乱码问题的解决方法
2015/04/08 Python
Jupyter安装nbextensions,启动提示没有nbextensions库
2020/04/23 Python
python验证码识别教程之利用投影法、连通域法分割图片
2018/06/04 Python
python+Splinter实现12306抢票功能
2018/09/25 Python
python获取地震信息 微信实时推送
2019/06/18 Python
Python实现简单的列表冒泡排序和反转列表操作示例
2019/07/10 Python
python脚本监控logstash进程并邮件告警实例
2020/04/28 Python
CSS3让登陆面板3D旋转起来
2016/05/03 HTML / CSS
旅游管理专业个人求职信范文
2013/12/24 职场文书
巾帼文明岗申报材料
2014/05/01 职场文书
党员教师一句话承诺
2014/05/30 职场文书
人力资源管理专业求职信
2014/07/23 职场文书
端午节活动总结
2014/08/26 职场文书
住房抵押登记委托书
2014/09/27 职场文书
车间主任岗位职责范本
2015/04/08 职场文书
现实表现证明材料
2015/06/19 职场文书
宣传委员竞选稿
2015/11/19 职场文书
中学生打架《检讨书》范文
2019/08/12 职场文书