pytorch::Dataloader中的迭代器和生成器应用详解


Posted in Python onJanuary 03, 2020

在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。

为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。

这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。

本文的内容主要有:

  • 解释python中的迭代器和生成器概念
  • 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有以下概念:

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。

学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。

因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?

迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。

当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:

  • 把大的数据分成几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:

  • for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,可以从0开始按顺序遍历序列容器中的元素。
  • for x in iterator: 为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  • for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
  print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
  print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
  print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:

  • __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
  • __iter__可以返回自身(return self),实际读取数据的实现放在__next__方法
  • __iter__可以和yield搭配,返回生成器对象

__iter__返回自身的做法有点类似 python中的类型系统。为了保持一致性,python中一切皆对象。

每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。

生成器,是在__iter__方法中加入yield语句,好处有:

  • 减少循环判断逻辑的复杂度
  • 惰性取值,节省内存和时间

yield作用:

  • 代替函数中的return语句
  • 记住上一次循环迭代器内部元素的位置

三种循环模式常用函数

for x in container 方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator 方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator 方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用 for x in iterator 模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

以下代码只截取了单线程下的数据读取。

class DataLoader(object):
  r"""
  Data loader. Combines a dataset and a sampler, and provides
  single- or multi-process iterators over the dataset.
  """
  def __init__(self, dataset, batch_size=1, shuffle=False, ...):
    self.dataset = dataset
    self.batch_sampler = batch_sampler
    ...
  
  def __iter__(self):
    return _DataLoaderIter(self)

  def __len__(self):
    return len(self.batch_sampler)

class _DataLoaderIter(object):
  r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
  def __init__(self, loader):
    self.sample_iter = iter(self.batch_sampler)
    ...

  def __next__(self):
    if self.num_workers == 0: # same-process loading
      indices = next(self.sample_iter) # may raise StopIteration
      batch = self.collate_fn([self.dataset[i] for i in indices])
      if self.pin_memory:
        batch = pin_memory_batch(batch)
      return batch
    ...

  def __iter__(self):
    return self

Dataloader类中读取数据Index的方法,采用了 for x in generator 方式,但是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
  """random sampler to yield a mini-batch of indices."""
  def __init__(self, batch_size, dataset, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_imgs = len(dataset)
    self.drop_last = drop_last

  def __iter__(self):
    indices = np.random.permutation(self.num_imgs)
    batch = []
    for i in indices:
      batch.append(i)
      if len(batch) == self.batch_size:
        yield batch
        batch = []
    ## if images not to yield a batch
    if len(batch)>0 and not self.drop_last:
      yield batch


  def __len__(self):
    if self.drop_last:
      return self.num_imgs // self.batch_size
    else:
      return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  • for x in container 可迭代对象
  • for x in iterator 迭代器
  • for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 流程控制实例代码
Sep 25 Python
详解Python中用于计算指数的exp()方法
May 14 Python
Python保存MongoDB上的文件到本地的方法
Mar 16 Python
python正则表达式re之compile函数解析
Oct 25 Python
如何使用VSCode愉快的写Python于调试配置步骤
Apr 06 Python
python实现时间o(1)的最小栈的实例代码
Jul 23 Python
Django自定义模板过滤器和标签的实现方法
Aug 21 Python
redis数据库及与python交互用法简单示例
Nov 01 Python
Python使用urllib模块对URL网址中的中文编码与解码实例详解
Feb 18 Python
python实现扫雷游戏的示例
Oct 20 Python
Python 实现RSA加解密文本文件
Dec 30 Python
如何利用Matlab制作一款真正的拼图小游戏
May 11 Python
django商品分类及商品数据建模实例详解
Jan 03 #Python
PyTorch和Keras计算模型参数的例子
Jan 02 #Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 #Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
You might like
PHP 中执行排序与 MySQL 中排序
2009/04/21 PHP
PHP管理内存函数 memory_get_usage()使用介绍
2012/09/23 PHP
php面向对象中static静态属性与方法的内存位置分析
2015/02/08 PHP
php 无限分类 树形数据格式化代码
2016/10/11 PHP
JavaScript设计模式之抽象工厂模式介绍
2014/12/28 Javascript
jquery.mobile 共同布局遇到的问题小结
2015/02/10 Javascript
js鼠标单击和双击事件冲突问题的快速解决方法
2016/07/11 Javascript
JS实现重新加载当前页面
2016/11/29 Javascript
vue + socket.io实现一个简易聊天室示例代码
2017/03/06 Javascript
javascript过滤数组重复元素的实现方法
2017/05/03 Javascript
jQuery Position方法使用和兼容性
2017/08/23 jQuery
详解extract-text-webpack-plugin 的使用及安装
2018/06/12 Javascript
使用vue.js在页面内组件监听scroll事件的方法
2018/09/11 Javascript
如何用原生js写一个弹窗消息提醒插件
2019/05/24 Javascript
vue 强制组件重新渲染(重置)的两种方案
2019/10/29 Javascript
js回调函数仿360开机
2019/12/26 Javascript
Node.js API详解之 readline模块用法详解
2020/05/22 Javascript
[01:27:43]VGJ.S vs TNC Supermajor 败者组 BO3 第三场 6.6
2018/06/07 DOTA
[01:03:41]DOTA2-DPC中国联赛 正赛 Dynasty vs XG BO3 第三场 2月2日
2021/03/11 DOTA
Python生成随机MAC地址
2015/03/10 Python
python threading模块操作多线程介绍
2015/04/08 Python
Python编程实现双击更新所有已安装python模块的方法
2017/06/05 Python
Python + selenium + requests实现12306全自动抢票及验证码破解加自动点击功能
2018/11/23 Python
pygame实现成语填空游戏
2019/10/29 Python
python global和nonlocal用法解析
2020/02/03 Python
python基于selenium爬取斗鱼弹幕
2021/02/20 Python
Python爬取你好李焕英豆瓣短评生成词云的示例代码
2021/02/24 Python
用canvas画心电图的示例代码
2018/09/10 HTML / CSS
Vilebrequin欧洲官网:法国豪华泳装品牌(男士沙滩裤)
2018/04/14 全球购物
LUISAVIAROMA中国官网:时尚奢侈品牌购物网站
2020/11/01 全球购物
奥巴马开学演讲稿
2014/05/15 职场文书
学校领导班子对照检查材料
2014/09/24 职场文书
音乐剧猫观后感
2015/06/04 职场文书
JavaScript执行机制详细介绍
2021/12/06 Javascript
Python socket如何解析HTTP请求内容
2022/02/12 Python
go goth封装第三方认证库示例详解
2022/08/14 Golang