Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 分析Nginx访问日志并保存到MySQL数据库实例
Mar 13 Python
Python新手实现2048小游戏
Mar 31 Python
用Python制作简单的朴素基数估计器的教程
Apr 01 Python
Python的Flask框架应用调用Redis队列数据的方法
Jun 06 Python
Python增量循环删除MySQL表数据的方法
Sep 23 Python
python django 实现验证码的功能实例代码
May 18 Python
Django跨域请求CSRF的方法示例
Nov 11 Python
利用Pyhton中的requests包进行网页访问测试的方法
Dec 26 Python
python dlib人脸识别代码实例
Apr 04 Python
python shell命令行中import多层目录下的模块操作
Mar 09 Python
pycharm debug 断点调试心得分享
Apr 16 Python
使用Python解决图表与画布的间距问题
Apr 11 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
php在文件指定行中写入代码的方法
2012/05/23 PHP
php获取文件名后缀常用方法小结
2015/02/24 PHP
PHP实现随机生成水印图片功能
2017/03/22 PHP
Yii2使用表单上传文件的实例代码
2017/08/03 PHP
javascript模仿msgbox提示效果代码
2008/06/10 Javascript
JQuery获取元素文档大小、偏移和位置和滚动条位置的方法集合
2010/01/12 Javascript
基于jquery的一个图片hover的插件
2010/04/24 Javascript
Javascript中的delete介绍
2012/09/02 Javascript
基于jquery固定于顶部的导航响应浏览器滚动条事件
2014/11/02 Javascript
理解javascript封装
2016/02/23 Javascript
JS中多步骤多分步的StepJump组件实例详解
2016/04/01 Javascript
javascript小数精度丢失的完美解决方法
2016/05/31 Javascript
BootStrap 轮播插件(carousel)支持左右手势滑动的方法(三种)
2016/07/07 Javascript
JavaScript构建自己的对象示例
2016/11/29 Javascript
JavaScript奇技淫巧44招【实用】
2016/12/11 Javascript
Three.js利用Detector.js插件如何实现兼容性检测详解
2017/09/26 Javascript
在 Angular中 使用 Lodash 的方法
2018/02/11 Javascript
详解基于Vue2.0实现的移动端弹窗(Alert, Confirm, Toast)组件
2018/08/02 Javascript
VScode格式化ESlint方法(最全最好用方法)
2019/09/10 Javascript
Vue实现腾讯云点播视频上传功能的实现代码
2020/08/17 Javascript
JavaScript 中的六种循环方法
2021/01/06 Javascript
[52:40]完美世界DOTA2联赛PWL S2 Magma vs GXR 第一场 11.29
2020/12/02 DOTA
Python中的浮点数原理与运算分析
2017/10/12 Python
Python使用 Beanstalkd 做异步任务处理的方法
2018/04/24 Python
pycharm下查看python的变量类型和变量内容的方法
2018/06/26 Python
基于python实现学生管理系统
2018/10/17 Python
python调用外部程序的实操步骤
2019/03/04 Python
python 字符串常用函数详解
2019/09/11 Python
TensorFlow实现自定义Op方式
2020/02/04 Python
python读取excel数据并且画图的实现示例
2021/02/08 Python
CSS3之边框多颜色Border-color属性使用示例
2013/10/11 HTML / CSS
谈一谈HTML5本地存储技术
2016/03/02 HTML / CSS
大客户销售经理职责
2013/12/04 职场文书
总经理职责
2013/12/22 职场文书
MySQL快速插入一亿测试数据
2021/06/23 MySQL
MySQL数据库配置信息查看与修改方法详解
2022/06/25 MySQL