pytorch::Dataloader中的迭代器和生成器应用详解


Posted in Python onJanuary 03, 2020

在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。

为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。

这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。

本文的内容主要有:

  • 解释python中的迭代器和生成器概念
  • 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有以下概念:

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。

学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。

因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?

迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。

当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:

  • 把大的数据分成几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:

  • for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,可以从0开始按顺序遍历序列容器中的元素。
  • for x in iterator: 为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  • for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
  print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
  print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
  print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:

  • __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
  • __iter__可以返回自身(return self),实际读取数据的实现放在__next__方法
  • __iter__可以和yield搭配,返回生成器对象

__iter__返回自身的做法有点类似 python中的类型系统。为了保持一致性,python中一切皆对象。

每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。

生成器,是在__iter__方法中加入yield语句,好处有:

  • 减少循环判断逻辑的复杂度
  • 惰性取值,节省内存和时间

yield作用:

  • 代替函数中的return语句
  • 记住上一次循环迭代器内部元素的位置

三种循环模式常用函数

for x in container 方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator 方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator 方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用 for x in iterator 模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

以下代码只截取了单线程下的数据读取。

class DataLoader(object):
  r"""
  Data loader. Combines a dataset and a sampler, and provides
  single- or multi-process iterators over the dataset.
  """
  def __init__(self, dataset, batch_size=1, shuffle=False, ...):
    self.dataset = dataset
    self.batch_sampler = batch_sampler
    ...
  
  def __iter__(self):
    return _DataLoaderIter(self)

  def __len__(self):
    return len(self.batch_sampler)

class _DataLoaderIter(object):
  r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
  def __init__(self, loader):
    self.sample_iter = iter(self.batch_sampler)
    ...

  def __next__(self):
    if self.num_workers == 0: # same-process loading
      indices = next(self.sample_iter) # may raise StopIteration
      batch = self.collate_fn([self.dataset[i] for i in indices])
      if self.pin_memory:
        batch = pin_memory_batch(batch)
      return batch
    ...

  def __iter__(self):
    return self

Dataloader类中读取数据Index的方法,采用了 for x in generator 方式,但是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
  """random sampler to yield a mini-batch of indices."""
  def __init__(self, batch_size, dataset, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_imgs = len(dataset)
    self.drop_last = drop_last

  def __iter__(self):
    indices = np.random.permutation(self.num_imgs)
    batch = []
    for i in indices:
      batch.append(i)
      if len(batch) == self.batch_size:
        yield batch
        batch = []
    ## if images not to yield a batch
    if len(batch)>0 and not self.drop_last:
      yield batch


  def __len__(self):
    if self.drop_last:
      return self.num_imgs // self.batch_size
    else:
      return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  • for x in container 可迭代对象
  • for x in iterator 迭代器
  • for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作MongoDB基础知识
Nov 01 Python
python批量制作雷达图的实现方法
Jul 26 Python
python删除本地夹里重复文件的方法
Nov 19 Python
对python遍历文件夹中的所有jpg文件的实例详解
Dec 08 Python
强悍的Python读取大文件的解决方案
Feb 16 Python
详解Python 函数如何重载?
Apr 23 Python
Python变量访问权限控制详解
Jun 29 Python
django框架forms组件用法实例详解
Dec 10 Python
浅谈spring boot 集成 log4j 解决与logback冲突的问题
Feb 20 Python
Django使用list对单个或者多个字段求values值实例
Mar 31 Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 Python
Pycharm连接远程服务器并远程调试的全过程
Jun 24 Python
django商品分类及商品数据建模实例详解
Jan 03 #Python
PyTorch和Keras计算模型参数的例子
Jan 02 #Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 #Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
You might like
php设计模式 Visitor 访问者模式
2011/06/28 PHP
PHP Cookie的使用教程详解
2013/06/03 PHP
win平台安装配置Nginx+php+mysql 环境
2016/01/12 PHP
微信小程序发送订阅消息的方法(php 为例)
2019/10/30 PHP
JS 显示当前日期与时间的代码
2010/03/24 Javascript
brook javascript框架介绍
2011/10/10 Javascript
非常好用的JsonToString 方法 简单实例
2013/07/18 Javascript
jquery实现html页面 div 假分页有原理有代码
2014/09/06 Javascript
javascript实现给定半径求出圆的面积
2015/06/26 Javascript
JavaScript让Textarea支持tab按键的方法
2015/06/26 Javascript
jQuery超精致图片轮播幻灯片特效代码分享
2015/09/10 Javascript
javascript每日必学之基础入门
2016/02/16 Javascript
如何提高Dom访问速度
2017/01/05 Javascript
浅谈js中function的参数默认值
2017/02/20 Javascript
js数字滑动时钟的简单实现(示例讲解)
2017/08/14 Javascript
使用Bootstrap + Vue.js实现表格的动态展示、新增和删除功能
2017/11/27 Javascript
webstorm+vue初始化项目的方法
2018/10/18 Javascript
node实现简单的增删改查接口实例代码
2019/08/22 Javascript
JS实现移动端在线签协议功能
2019/08/22 Javascript
Ant design vue table 单击行选中 勾选checkbox教程
2020/10/24 Javascript
ant design vue导航菜单与路由配置操作
2020/10/28 Javascript
简单说明Python中的装饰器的用法
2015/04/24 Python
利用pyinstaller或virtualenv将python程序打包详解
2017/03/22 Python
python不换行之end=与逗号的意思及用途
2017/11/21 Python
分析运行中的 Python 进程详细解析
2019/06/22 Python
python批量修改图片尺寸,并保存指定路径的实现方法
2019/07/04 Python
python基于json文件实现的gearman任务自动重启代码实例
2019/08/13 Python
python numpy 反转 reverse示例
2019/12/04 Python
利用python实现.dcm格式图像转为.jpg格式
2020/01/13 Python
如何使用python代码操作git代码
2020/02/29 Python
英国领先的男士服装和时尚零售商:Burton
2017/01/09 全球购物
比利时香水网上商店:NOTINO
2018/03/28 全球购物
德国前卫设计师时装在线商店:Luxury Loft
2019/11/04 全球购物
Tomcat Mysql datasource数据源配置
2015/12/28 面试题
司法局群众路线教育实践活动整改措施
2014/09/17 职场文书
七夕情人节问候语
2015/11/11 职场文书