Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python重新引入被覆盖的自带function
Jul 16 Python
详解Python的Django框架中Manager方法的使用
Jul 21 Python
详解Python中的array数组模块相关使用
Jul 05 Python
python定向爬虫校园论坛帖子信息
Jul 23 Python
pandas把所有大于0的数设置为1的方法
Jan 26 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
Jul 06 Python
cProfile Python性能分析工具使用详解
Jul 22 Python
Python定时发送天气预报邮件代码实例
Sep 09 Python
Keras load_model 导入错误的解决方式
Jun 09 Python
python3字符串输出常见面试题总结
Dec 01 Python
Python爬虫爬取ts碎片视频+验证码登录功能
Feb 22 Python
python爬虫破解字体加密案例详解
Mar 02 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
PHP 数据库树的遍历方法
2009/02/06 PHP
php 定义404页面的实现代码
2012/11/19 PHP
php后门URL的防范
2013/11/12 PHP
PHP+Ajax实现的无刷新分页功能详解【附demo源码下载】
2017/07/03 PHP
PHP实现登录注册之BootStrap表单功能
2017/09/03 PHP
如何在PHP中使用AES加密算法加密数据
2020/06/24 PHP
jqueyr判断checkbox组的选中(示例代码)
2013/11/08 Javascript
Jquery方式获取iframe页面中的 Dom元素
2014/05/07 Javascript
JavaScript数值转换的三种方式总结
2014/07/31 Javascript
jQuery中last()方法用法实例
2015/01/06 Javascript
JavaScript实现模仿桌面窗口的方法
2015/07/18 Javascript
JavaScript的设计模式经典之建造者模式
2016/02/24 Javascript
Javascript中apply、call、bind的巧妙使用
2016/08/18 Javascript
jQuery点击弹出层弹出模态框点击模态框消失代码分享
2017/01/21 Javascript
vue中用动态组件实现选项卡切换效果
2017/03/25 Javascript
详解JS中遍历语法的比较
2017/04/07 Javascript
JS 组件系列之 bootstrap treegrid 组件封装过程
2017/04/28 Javascript
React 路由懒加载的几种实现方案
2018/10/23 Javascript
JS简单数组排序操作示例【sort方法】
2019/05/17 Javascript
axios实现简单文件上传功能
2019/09/25 Javascript
vue项目中企业微信使用js-sdk时config和agentConfig配置方式详解
2020/12/15 Vue.js
Python爬虫包BeautifulSoup简介与安装(一)
2018/06/17 Python
Django中URL的参数传递的实现
2019/08/04 Python
Flask框架学习笔记之模板操作实例详解
2019/08/15 Python
python 一篇文章搞懂装饰器所有用法(建议收藏)
2019/08/23 Python
Python中注释(多行注释和单行注释)的用法实例
2019/08/28 Python
Python实现汇率转换操作
2020/05/03 Python
纯CSS实现聊天框小尖角、气泡效果
2014/04/04 HTML / CSS
html5 canvas-1.canvas介绍(hello canvas)
2013/01/07 HTML / CSS
英国领先的奢侈品零售商之一:CRUISE
2016/12/02 全球购物
Superdry极度干燥美国官网:英国制造的服装品牌
2018/11/13 全球购物
智能家居、吸尘器、滑板车、电动自行车网上购物:Geekmaxi
2021/01/18 全球购物
《雷鸣电闪波尔卡》教学反思
2014/02/23 职场文书
小学生评语集锦
2014/04/18 职场文书
2015年网络管理员工作总结
2015/05/21 职场文书
2015年秋学期教研工作总结
2015/10/14 职场文书