Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中条件选择和循环语句使用方法介绍
Mar 13 Python
python基础入门学习笔记(Python环境搭建)
Jan 13 Python
windows10系统中安装python3.x+scrapy教程
Nov 08 Python
Python中正则表达式详解
May 17 Python
Python实现的寻找前5个默尼森数算法示例
Mar 25 Python
python 统计一个列表当中的每一个元素出现了多少次的方法
Nov 14 Python
pytorch 转换矩阵的维数位置方法
Dec 08 Python
python实现弹窗祝福效果
Apr 07 Python
解决Python对齐文本字符串问题
Aug 28 Python
pytorch实现mnist分类的示例讲解
Jan 10 Python
详解Django中views数据查询使用locals()函数进行优化
Aug 24 Python
python IP地址转整数
Nov 20 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
PHILIPS L4X25T电路分析和打理
2021/03/02 无线电
PHP 进程锁定问题分析研究
2009/11/24 PHP
PHP仿盗链代码
2012/06/03 PHP
PHP 图片水印类代码
2012/08/27 PHP
PHP数据对象PDO操作技巧小结
2016/09/27 PHP
PHP fprintf()函数用法讲解
2019/02/16 PHP
php远程请求CURL实例教程(爬虫、保存登录状态)
2020/12/10 PHP
用javascript编写的第一人称射击游戏
2007/02/25 Javascript
JavaScript中清空数组的三种方法分享
2011/04/07 Javascript
情人节专属 纯js脚本1k大小的3D玫瑰效果
2012/02/11 Javascript
JS小游戏之极速快跑源码详解
2014/09/25 Javascript
jQuery满意度星级评价插件特效代码分享
2015/08/19 Javascript
js 判断所选时间(或者当前时间)是否在某一时间段的实现代码
2015/09/05 Javascript
JavaScript实现仿淘宝商品购买数量的增减效果
2016/01/22 Javascript
JavaScript中style.left与offsetLeft的使用及区别详解
2016/06/08 Javascript
html+js+highcharts绘制圆饼图表的简单实例
2016/08/04 Javascript
浅析Jquery操作select
2016/12/13 Javascript
vue.js todolist实现代码
2017/10/29 Javascript
Vue全家桶实践项目总结(推荐)
2017/11/04 Javascript
在vue中通过axios异步使用echarts的方法
2018/01/13 Javascript
Angular 组件之间的交互的示例代码
2018/03/24 Javascript
JavaScript解决浮点数计算不准确问题的方法分析
2018/07/09 Javascript
[00:37]DOTA2上海特级锦标赛 OG战队宣传片
2016/03/03 DOTA
Python文件操作类操作实例详解
2014/07/11 Python
Python编程中的反模式实例分析
2014/12/08 Python
Python递归函数定义与用法示例
2017/06/02 Python
python2.7到3.x迁移指南
2018/02/01 Python
python numpy库linspace相同间隔采样的实现
2020/02/25 Python
Python requests.post方法中data与json参数区别详解
2020/04/30 Python
PyQt5结合matplotlib绘图的实现示例
2020/09/15 Python
基于css3 animate制作绚丽的动画效果
2015/11/24 HTML / CSS
房产授权委托书范本
2014/09/22 职场文书
离婚协议书范本及离婚须知
2014/10/15 职场文书
世界文化遗产导游词
2015/02/13 职场文书
女性励志书籍推荐
2019/08/19 职场文书
Python如何利用pandas读取csv数据并绘图
2022/07/07 Python