Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数之filter map reduce介绍
Nov 30 Python
将Python代码嵌入C++程序进行编写的实例
Jul 31 Python
Python爬虫抓取手机APP的传输数据
Jan 22 Python
Python字符串格式化的方法(两种)
Sep 19 Python
Python对数据进行插值和下采样的方法
Jul 03 Python
Python 3.x 判断 dict 是否包含某键值的实例讲解
Jul 06 Python
Python hexstring-list-str之间的转换方法
Jun 12 Python
解决python 3 urllib 没有 urlencode 属性的问题
Aug 22 Python
Python列表list常用内建函数实例小结
Oct 22 Python
Django框架实现在线考试系统的示例代码
Nov 30 Python
python 读取yaml文件的两种方法(在unittest中使用)
Dec 01 Python
Python获取字典中某个key的value
Apr 13 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
PHP设计模式(九)外观模式Facade实例详解【结构型】
2020/05/02 PHP
PHP 超级全局变量相关总结
2020/06/30 PHP
用JavaScript实现单继承和多继承的简单方法
2009/03/29 Javascript
javascript 限制输入脚本大全
2009/11/03 Javascript
屏蔽网页右键复制和ctrl+c复制的js代码
2013/01/04 Javascript
网站内容禁止复制和粘贴、另存为的js代码
2014/02/26 Javascript
JavaScript简单表格编辑功能实现方法
2015/04/16 Javascript
JavaScript+html5 canvas制作色彩斑斓的正方形效果
2016/01/27 Javascript
Bootstrap所支持的表单控件实例详解
2016/05/16 Javascript
KnockoutJs快速入门教程
2016/05/16 Javascript
AngularJs解决跨域问题案例详解(简单方法)
2016/05/19 Javascript
jQuery图片切换动画特效
2016/11/02 Javascript
js 单引号替换成双引号,双引号替换成单引号的实现方法
2017/02/16 Javascript
将input框中输入内容显示在相应的div中【三种方法可选】
2017/05/08 Javascript
vue实现手机号码抽奖上下滚动动画示例
2017/10/18 Javascript
在Vue中使用Compass的方法
2018/03/02 Javascript
Webpack 之 babel-loader文件预处理器详解
2018/03/23 Javascript
详解从Vue-router到html5的pushState
2018/07/21 Javascript
浅析vue 函数配置项watch及函数 $watch 源码分享
2018/11/22 Javascript
Vue CLI3创建项目部署到Tomcat 使用ngrok映射到外网
2019/05/16 Javascript
vue项目创建步骤及路由router
2020/01/14 Javascript
jQuery+ThinkPHP实现图片上传
2020/07/23 jQuery
[01:05:40]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS DT第三场
2014/05/24 DOTA
[01:39]2014DOTA2国际邀请赛 Newbee经理CU专访队伍火力全开
2014/07/15 DOTA
举例讲解Python程序与系统shell交互的方式
2015/04/09 Python
在Mac上删除自己安装的Python方法
2018/10/29 Python
利用PyCharm Profile分析异步爬虫效率详解
2019/05/08 Python
一篇文章搞懂python的转义字符及用法
2020/09/03 Python
python实现人性化显示金额数字实例详解
2020/09/25 Python
python装饰器代码深入讲解
2021/03/01 Python
金智子午JAVA面试题
2015/09/04 面试题
教学改革实施方案
2014/03/31 职场文书
舞蹈教育学专业自荐信
2014/06/15 职场文书
禁止酒驾标语
2014/06/25 职场文书
个人年终总结开头
2015/03/06 职场文书
升学宴学生致辞
2015/07/27 职场文书