一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系


Posted in Python onJuly 03, 2020

以下内容都是针对Pytorch 1.0-1.1介绍。

很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

class DataLoader(object):
	...
	
 def __next__(self):
  if self.num_workers == 0: 
   indices = next(self.sample_iter) # Sampler
   batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
   if self.pin_memory:
    batch = _utils.pin_memory.pin_memory_batch(batch)
   return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

综上可以知道DataLoader,Sampler和Dataset三者关系如下:

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。

Sampler

参数传递

要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

class DataLoader(object):
 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
     batch_sampler=None, num_workers=0, collate_fn=default_collate,
     pin_memory=False, drop_last=False, timeout=0,
     worker_init_fn=None)

可以看到初始化参数里有两种sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已经实现的Sampler有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:

  • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_size, shuffle,sampler,drop_last.
  • 如果你自定义了sampler,那么shuffle需要设置为False
  • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
    • 若shuffle=True,则sampler=RandomSampler(dataset)
    • 若shuffle=False,则sampler=SequentialSampler(dataset)

如何自定义Sampler和BatchSampler?

仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:

class Sampler(object):
 r"""Base class for all Samplers.
 Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
 way to iterate over indices of dataset elements, and a :meth:`__len__` method
 that returns the length of the returned iterators.
 .. note:: The :meth:`__len__` method isn't strictly required by
    :class:`~torch.utils.data.DataLoader`, but is expected in any
    calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
 """

 def __init__(self, data_source):
  pass

 def __iter__(self):
  raise NotImplementedError
		
 def __len__(self):
  return len(self.data_source)

所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。

Dataset

Dataset定义方式如下:

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

class DataLoader(object): 
 ... 
  
 def __next__(self): 
  if self.num_workers == 0: 
   indices = next(self.sample_iter) 
   batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
   if self.pin_memory: 
    batch = _utils.pin_memory.pin_memory_batch(batch) 
   return batch

我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)

看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

到此这篇关于一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系的文章就介绍到这了,更多相关Pytorch DataLoader DataSet Sampler内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现将绝对URL替换成相对URL的方法
Jun 28 Python
深入浅析Python中join 和 split详解(推荐)
Jun 30 Python
python制作小说爬虫实录
Aug 14 Python
利用Python暴力破解zip文件口令的方法详解
Dec 21 Python
Python使用matplotlib模块绘制图像并设置标题与坐标轴等信息示例
May 04 Python
python输出决策树图形的例子
Aug 09 Python
Python面向对象中类(class)的简单理解与用法分析
Feb 21 Python
Python开发之身份证验证库id_validator验证身份证号合法性及根据身份证号返回住址年龄等信息
Mar 20 Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 Python
python如何保存文本文件
Jun 07 Python
解决pyinstaller 打包exe文件太大,用pipenv 缩小exe的问题
Jul 13 Python
详解基于python的图像Gabor变换及特征提取
Oct 26 Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
You might like
discuz程序的PHP加密函数原理分析
2011/08/05 PHP
php命令行(cli)下执行PHP脚本文件的相对路径的问题解决方法
2015/05/25 PHP
PHP实现基于栈的后缀表达式求值功能
2017/11/10 PHP
javascript 事件绑定问题
2011/01/01 Javascript
基于jquery的禁用右键、文本选择功能、复制按键的实现代码
2013/08/27 Javascript
Node.js中的模块机制学习笔记
2014/11/04 Javascript
jquery表单对象属性过滤选择器实例分析
2015/05/18 Javascript
微信小程序 页面跳转传参详解
2016/10/28 Javascript
微信小程序 网络API发起请求详解
2016/11/09 Javascript
推荐三款日期选择插件(My97DatePicker、jquery.datepicker、Mobiscroll)
2017/04/21 jQuery
javascript简单写的判断电话号码实例
2017/05/24 Javascript
Angular2的管道Pipe的使用方法
2017/11/07 Javascript
把vue-router和express项目部署到服务器的方法
2018/02/21 Javascript
Js经典案例的实例代码
2018/05/10 Javascript
小程序实现选择题选择效果
2018/11/04 Javascript
微信小程序实现文字无限轮播效果
2018/12/28 Javascript
关于AngularJS中几种Providers的区别总结
2020/05/17 Javascript
使用Mock.js生成前端测试数据
2020/12/13 Javascript
[02:44]重置世界,颠覆未来——DOTA2 7.23版本震撼上线
2019/12/01 DOTA
[01:14:41]DOTA2-DPC中国联赛定级赛 iG vs Magma BO3第一场 1月8日
2021/03/11 DOTA
Python  __getattr__与__setattr__使用方法
2008/09/06 Python
python简单程序读取串口信息的方法
2015/03/13 Python
介绍Python中几个常用的类方法
2015/04/08 Python
Python中模块pymysql查询结果后如何获取字段列表
2017/06/05 Python
python的scikit-learn将特征转成one-hot特征的方法
2018/07/10 Python
Python multiprocess pool模块报错pickling error问题解决方法分析
2019/03/20 Python
Python 计算任意两向量之间的夹角方法
2019/07/05 Python
Jupyter Notebook折叠输出的内容实例
2020/04/22 Python
python如何停止递归
2020/09/09 Python
YSL圣罗兰美妆英国官网:Yves Saint Laurent Beauty UK
2019/08/03 全球购物
校本教研工作方案
2014/01/14 职场文书
节约电力资源的建议书
2014/03/12 职场文书
党员干部形式主义个人整改措施
2014/09/17 职场文书
构建和谐校园倡议书
2015/01/19 职场文书
Python中Matplotlib的点、线形状、颜色以及绘制散点图
2022/04/07 Python
python 学习GCN图卷积神经网络
2022/05/11 Python