pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
Python读取环境变量的方法和自定义类分享
Nov 22 Python
Python自动化构建工具scons使用入门笔记
Mar 10 Python
介绍Python中的一些高级编程技巧
Apr 02 Python
python正常时间和unix时间戳相互转换的方法
Apr 23 Python
深入理解Python分布式爬虫原理
Nov 23 Python
Python爬虫实现爬取京东手机页面的图片(实例代码)
Nov 30 Python
python自动查询12306余票并发送邮箱提醒脚本
May 21 Python
python微信公众号之关注公众号自动回复
Oct 25 Python
python 发送和接收ActiveMQ消息的实例
Jan 30 Python
Python 时间戳之获取整点凌晨时间戳的操作方法
Jan 28 Python
Kmeans均值聚类算法原理以及Python如何实现
Sep 26 Python
Django利用AJAX技术实现博文实时搜索
May 06 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
英雄试炼之肉山谷—引领RPG新潮流
2020/04/20 DOTA
PHP return语句的另一个作用
2014/07/30 PHP
PHP实现时间比较和时间差计算的方法示例
2017/07/24 PHP
php使用curl下载指定大小的文件实例代码
2017/09/30 PHP
PHP关于foreach复制知识点总结
2019/01/28 PHP
快速保存网页中所有图片的方法
2006/06/23 Javascript
window.location和document.location的区别分析
2008/12/23 Javascript
Jquery 快速构建可拖曳的购物车DragDrop
2009/11/30 Javascript
基于jquery的复制网页内容到WORD的实现代码
2011/02/16 Javascript
深入理解JavaScript系列(6) 强大的原型和原型链
2012/01/15 Javascript
hover的用法及live的用法介绍(鼠标悬停效果)
2013/03/29 Javascript
把input初始值不写value的具体实现方法
2013/07/04 Javascript
angularjs学习笔记之三大模块(modal,controller,view)
2015/09/26 Javascript
jQuery Easyui实现左右布局
2016/01/26 Javascript
详解JS几种变量交换方式以及性能分析对比
2016/11/25 Javascript
Node中使用ES6语法的基础教程
2018/01/05 Javascript
微信小程序图片左右摆动效果详解
2019/07/13 Javascript
JS 实现发送短信验证码的“59秒后重新发送验证短信”功能
2019/08/23 Javascript
antd-mobile ListView长列表的数据更新遇到的坑
2020/04/08 Javascript
python开发的小球完全弹性碰撞游戏代码
2013/10/15 Python
Python机器学习之决策树算法实例详解
2017/12/06 Python
Python之多线程爬虫抓取网页图片的示例代码
2018/01/10 Python
python编程使用selenium模拟登陆淘宝实例代码
2018/01/25 Python
详解Python中的动态属性和特性
2018/04/07 Python
python 调用钉钉机器人的方法
2019/02/20 Python
如何使用django的MTV开发模式返回一个网页
2019/07/22 Python
如何在VSCode上轻松舒适的配置Python的方法步骤
2019/10/28 Python
Python终端输出彩色字符方法详解
2020/02/11 Python
Win 10下Anaconda虚拟环境的教程
2020/05/18 Python
Windows 平台做 Python 开发的最佳组合(推荐)
2020/07/27 Python
python如何调用百度识图api
2020/09/29 Python
Sport-Thieme荷兰:购买体育用品
2019/08/25 全球购物
工程造价自荐信
2013/10/09 职场文书
企业内控岗位的职责
2014/02/07 职场文书
高三家长寄语
2014/04/03 职场文书
副科级后备干部考察材料
2014/05/15 职场文书