Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作xml文件详细介绍
Jun 09 Python
Python对象转JSON字符串的方法
Apr 27 Python
python3+mysql查询数据并通过邮件群发excel附件
Feb 24 Python
python实现人脸识别经典算法(一) 特征脸法
Mar 13 Python
利用python将json数据转换为csv格式的方法
Mar 22 Python
python 接口返回的json字符串实例
Mar 27 Python
Python绘制的二项分布概率图示例
Aug 22 Python
Python用5行代码写一个自定义简单二维码
Oct 21 Python
Python全局锁中如何合理运用多线程(多进程)
Nov 06 Python
对tensorflow 中tile函数的使用详解
Feb 07 Python
Windows下Sqlmap环境安装教程详解
Aug 04 Python
Python多线程实用方法以及共享变量资源竞争问题
Apr 12 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
php处理斐波那契数列非递归方法
2012/02/04 PHP
MySQL 日期时间函数常用总结
2012/06/12 PHP
Session服务器配置指南与使用经验的深入解析
2013/06/17 PHP
使用php 获取时间今天明天昨天时间戳的详解
2013/06/20 PHP
isArray()函数(JavaScript中对象类型判断的几种方法)
2009/11/26 Javascript
如何使用jQuery技术开发ios风格的页面导航菜单
2015/07/29 Javascript
JQuery移动页面开发之屏幕方向改变与滚屏的实现
2015/12/03 Javascript
jQuery实现漂亮实用的商品图片tips提示框效果(无图片箭头+阴影)
2016/04/16 Javascript
JavaScript 点击触发复制功能实例详解
2018/11/02 Javascript
在Web关闭页面时发送Ajax请求的实现方法
2019/03/07 Javascript
js设置鼠标悬停改变背景色实现详解
2019/06/26 Javascript
vue实现简单跑马灯效果
2020/05/25 Javascript
vue实现购物车加减
2020/05/30 Javascript
[01:02]DOTA2辉夜杯决赛日 CDEC.Y对阵VG赛前花絮
2015/12/27 DOTA
Python中用pycurl监控http响应时间脚本分享
2015/02/02 Python
Python函数参数类型*、**的区别
2015/04/11 Python
进一步理解Python中的函数编程
2015/04/13 Python
Python中List.index()方法的使用教程
2015/05/20 Python
Python 40行代码实现人脸识别功能
2017/04/02 Python
Python 反转字符串(reverse)的方法小结
2018/02/20 Python
Python3实现带附件的定时发送邮件功能
2020/12/22 Python
Python入门之后再看点什么好?
2018/03/05 Python
Python使用add_subplot与subplot画子图操作示例
2018/06/01 Python
TensorFlow打印tensor值的实现方法
2018/07/27 Python
对python3标准库httpclient的使用详解
2018/12/18 Python
Django中间件基础用法详解
2019/07/18 Python
基于python框架Scrapy爬取自己的博客内容过程详解
2019/08/05 Python
使用Pytorch来拟合函数方式
2020/01/14 Python
世界顶级足球门票网站:Live Football Tickets
2017/10/14 全球购物
美国最好的钓鱼、狩猎和划船装备商店:Bass Pro Shops
2018/12/02 全球购物
Hotels.com英国:全球领先的酒店住宿提供商
2019/01/24 全球购物
JAVA程序员面试题
2012/10/03 面试题
Java中的类包括什么内容?设计时要注意哪些方面
2012/05/23 面试题
大学生职业生涯规划方案
2014/01/03 职场文书
安全协议书范本
2014/04/21 职场文书
pytorch 如何使用float64训练
2021/05/24 Python