Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
解决python3中cv2读取中文路径的问题
Dec 05 Python
python使用pdfminer解析pdf文件的方法示例
Dec 20 Python
Python直接赋值、浅拷贝与深度拷贝实例分析
Jun 18 Python
ML神器:sklearn的快速使用及入门
Jul 11 Python
Python3分析处理声音数据的例子
Aug 27 Python
Python 支持向量机分类器的实现
Jan 15 Python
Python3标准库glob文件名模式匹配的问题
Mar 13 Python
Django更新models数据库结构步骤
Apr 01 Python
python函数中将变量名转换成字符串实例
May 11 Python
浅谈python量化 双均线策略(金叉死叉)
Jun 03 Python
Python基于数列实现购物车程序过程详解
Jun 09 Python
Python基于time模块表示时间常用方法
Jun 18 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
php类中private属性继承问题分析
2012/11/01 PHP
PHP输出九九乘法表代码实例
2015/03/27 PHP
ThinkPHP数据操作方法总结
2015/09/28 PHP
关于文本限制字数的js代码
2007/04/02 Javascript
Javascript学习笔记之 对象篇(三) : hasOwnProperty
2014/06/24 Javascript
详解JavaScript的Polymer框架中的通知交互
2015/07/29 Javascript
jquery+CSS3实现3D拖拽相册效果
2016/07/18 Javascript
探讨AngularJs中ui.route的简单应用
2016/11/16 Javascript
微信小程序 input表单与redio及下拉列表的使用实例
2017/09/20 Javascript
使用D3.js创建物流地图的示例代码
2018/01/27 Javascript
解决vue页面DOM操作不生效的问题
2018/03/17 Javascript
jQuery实现的电子时钟效果完整示例
2018/04/28 jQuery
浅谈Angularjs中不同类型的双向数据绑定
2018/07/16 Javascript
swiper在vue项目中loop循环轮播失效的解决方法
2018/09/15 Javascript
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
Vue实现boradcast和dispatch的示例
2020/11/13 Javascript
node.js文件的复制、创建文件夹等相关操作
2021/02/05 Javascript
[56:12]LGD vs Optic Supermajor小组赛D组胜者组决赛 BO3 第一场 6.3
2018/06/04 DOTA
使用PDB简单调试Python程序简明指南
2015/04/25 Python
Python私有属性私有方法应用实例解析
2020/09/15 Python
Linux的文件类型
2016/07/05 面试题
什么是makefile? 如何编写makefile?
2013/01/02 面试题
Linux的文件类型
2012/03/07 面试题
学生个人自我鉴定范文
2014/03/28 职场文书
委托书如何写
2014/08/30 职场文书
个人批评与自我批评
2014/10/15 职场文书
工作失误检讨书(经典集锦版)
2014/10/17 职场文书
企业党建工作总结2015
2015/05/26 职场文书
奔腾年代观后感
2015/06/09 职场文书
无婚姻登记记录证明
2015/06/18 职场文书
幼儿园教师教育随笔
2015/08/14 职场文书
无故旷工检讨书
2015/08/15 职场文书
redis连接被拒绝的解决方案
2021/04/12 Redis
openstack中的rpc远程调用的方法
2021/07/09 Python
css实现两栏布局,左侧固定宽,右侧自适应的多种方法
2021/08/07 HTML / CSS
什么是SOLID
2022/03/24 Javascript