pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
python中PIL安装简单教程
Apr 21 Python
Python数据结构之翻转链表
Feb 25 Python
利用pandas读取中文数据集的方法
Jul 25 Python
对python 判断数字是否小于0的方法详解
Jan 26 Python
Python3爬虫之自动查询天气并实现语音播报
Feb 21 Python
详解pytorch 0.4.0迁移指南
Jun 16 Python
python实现音乐播放器 python实现花框音乐盒子
Feb 25 Python
python实现可下载音乐的音乐播放器
Feb 25 Python
Python字符串split及rsplit方法原理详解
Jun 29 Python
python如何快速生成时间戳
Jul 21 Python
把Anaconda中的环境导入到Pycharm里面的方法步骤
Oct 30 Python
详解python网络进程
Jun 15 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
牡丹941资料
2021/03/01 无线电
GD输出汉字的函数的分析
2006/10/09 PHP
推荐一篇入门级的Class文章
2007/03/19 PHP
PHP限制页面只能在微信自带浏览器访问的代码
2014/01/15 PHP
使用配置类定义Codeigniter全局变量
2014/06/12 PHP
php图片的二进制转换实现方法
2014/12/15 PHP
使用XHGui来测试PHP性能的教程
2015/07/03 PHP
php抓取并保存网站图片的实现代码
2015/10/28 PHP
PHP从零开始打造自己的MVC框架之类的自动加载实现方法详解
2019/06/03 PHP
PHP基于swoole多进程操作示例
2019/08/12 PHP
php传值和传引用的区别点总结
2019/11/19 PHP
基于逻辑运算的简单权限系统(实现) JS 版
2007/03/24 Javascript
权威JavaScript 中的内存泄露模式
2007/08/13 Javascript
COM中获取JavaScript数组大小的代码
2009/11/22 Javascript
用js实现层随着内容大小动态渐变改变 推荐
2009/12/19 Javascript
JQuery的ajax获取数据后的处理总结(html,xml,json)
2010/07/14 Javascript
JS实现简易图片轮播效果的方法
2015/03/25 Javascript
浅析JavaScript动画模拟拖拽原理
2016/12/09 Javascript
通过vue-router懒加载解决首次加载时资源过多导致的速度缓慢问题
2018/04/08 Javascript
vue中使用vue-print.js实现多页打印
2020/03/05 Javascript
Vue实现手机扫描二维码预览页面效果
2020/05/28 Javascript
《javascript设计模式》学习笔记一:Javascript面向对象程序设计对象成员的定义分析
2020/04/07 Javascript
[42:23]完美世界DOTA2联赛PWL S3 Forest vs Rebirth 第二场 12.10
2020/12/13 DOTA
Python迭代器和生成器定义与用法示例
2018/02/10 Python
解决Python找不到ssl模块问题 No module named _ssl的方法
2019/04/29 Python
python 输出列表元素实例(以空格/逗号为分隔符)
2019/12/25 Python
python virtualenv虚拟环境配置与使用教程详解
2020/07/13 Python
英国剑桥包官网:The Cambridge Satchel Company
2016/08/01 全球购物
YesBabyOnline美国:全球性的在线婚纱礼服工厂
2018/05/05 全球购物
最新大学生创业计划书写作攻略
2014/04/02 职场文书
2014年英语工作总结
2014/12/20 职场文书
停电通知范文
2015/04/16 职场文书
2015年派出所民警工作总结
2015/04/24 职场文书
2015年学校综合治理工作总结
2015/07/20 职场文书
七年级写作指导之游记作文
2019/10/07 职场文书
用Python写一个简易版弹球游戏
2021/04/13 Python