Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python中装饰器的用法
Apr 27 Python
Python实现多线程抓取妹子图
Aug 08 Python
深入解析Python编程中super关键字的用法
Jun 24 Python
Python用UUID库生成唯一ID的方法示例
Dec 15 Python
Python 列表理解及使用方法
Oct 27 Python
Python cookbook(数据结构与算法)从序列中移除重复项且保持元素间顺序不变的方法
Mar 13 Python
Python File readlines() 使用方法
Mar 19 Python
Python中矩阵创建和矩阵运算方法
Aug 04 Python
Python 实现域名解析为ip的方法
Feb 14 Python
pyqt5 使用label控件实时显示时间的实例
Jun 14 Python
python GUI库图形界面开发之PyQt5图片显示控件QPixmap详细使用方法与实例
Feb 27 Python
python实现3D地图可视化
Mar 25 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
实用PHP会员权限控制实现原理分析
2011/05/29 PHP
PHP实现显示照片exif信息的方法
2014/07/11 PHP
完美解决thinkphp验证码出错无法显示的方法
2014/12/09 PHP
php通过分类列表产生分类树数组的方法
2015/04/20 PHP
如何通过View::first使用Laravel Blade的动态模板详解
2017/09/21 PHP
thinkPHP5框架接口写法简单示例
2019/08/05 PHP
Javascript打印网页部分内容的脚本
2008/11/17 Javascript
YUI Compressor压缩JavaScript原理及微优化
2013/01/07 Javascript
javascript自启动函数的问题探讨
2013/10/05 Javascript
js实现简单的星级选择器提交效果适用于评论等
2013/10/18 Javascript
原创jQuery弹出层插件分享
2015/04/02 Javascript
10个很棒的jQuery代码片段
2015/09/24 Javascript
实现高性能JavaScript之执行与加载
2016/01/30 Javascript
Bootstrap表格和栅格分页实例详解
2016/05/20 Javascript
详解AngularJS如何实现跨域请求
2016/08/22 Javascript
JS简单实现无缝滚动效果实例
2016/08/24 Javascript
Angular2的管道Pipe的使用方法
2017/11/07 Javascript
基于JSONP原理解析(推荐)
2017/12/04 Javascript
详解基于vue-cli配置移动端自适应
2018/01/13 Javascript
通过vue写一个瀑布流插件代码实例
2019/09/07 Javascript
vue-dplayer 视频播放器实例代码
2019/11/08 Javascript
Openlayers实现图形绘制
2020/09/28 Javascript
初学Python函数的笔记整理
2015/04/07 Python
在Python中使用dict和set方法的教程
2015/04/27 Python
Python端口扫描简单程序
2016/11/10 Python
centos6.5安装python3.7.1之后无法使用pip的解决方案
2019/02/14 Python
Django ORM实现按天获取数据去重求和例子
2020/05/18 Python
Pycharm插件(Grep Console)自定义规则输出颜色日志的方法
2020/05/27 Python
5分钟让你掌握css3阴影、倒影、渐变小技巧(小编推荐)
2016/08/15 HTML / CSS
HTML5 File接口在web页面上使用文件下载
2017/02/27 HTML / CSS
印尼极简主义和实惠的在线家具店:Fabelio
2019/03/27 全球购物
高级人员简历的自我评价分享
2013/11/03 职场文书
食品厂厂长岗位职责
2014/01/30 职场文书
中国合伙人观后感
2015/06/02 职场文书
创业计划书之网吧
2019/10/10 职场文书
PHP控制循环操作的时间
2021/04/01 PHP