pytorch::Dataloader中的迭代器和生成器应用详解


Posted in Python onJanuary 03, 2020

在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。

为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。

这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。

本文的内容主要有:

  • 解释python中的迭代器和生成器概念
  • 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有以下概念:

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。

学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。

因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?

迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。

当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:

  • 把大的数据分成几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:

  • for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,可以从0开始按顺序遍历序列容器中的元素。
  • for x in iterator: 为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  • for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
  print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
  print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
  print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:

  • __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
  • __iter__可以返回自身(return self),实际读取数据的实现放在__next__方法
  • __iter__可以和yield搭配,返回生成器对象

__iter__返回自身的做法有点类似 python中的类型系统。为了保持一致性,python中一切皆对象。

每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。

生成器,是在__iter__方法中加入yield语句,好处有:

  • 减少循环判断逻辑的复杂度
  • 惰性取值,节省内存和时间

yield作用:

  • 代替函数中的return语句
  • 记住上一次循环迭代器内部元素的位置

三种循环模式常用函数

for x in container 方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator 方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator 方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用 for x in iterator 模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

以下代码只截取了单线程下的数据读取。

class DataLoader(object):
  r"""
  Data loader. Combines a dataset and a sampler, and provides
  single- or multi-process iterators over the dataset.
  """
  def __init__(self, dataset, batch_size=1, shuffle=False, ...):
    self.dataset = dataset
    self.batch_sampler = batch_sampler
    ...
  
  def __iter__(self):
    return _DataLoaderIter(self)

  def __len__(self):
    return len(self.batch_sampler)

class _DataLoaderIter(object):
  r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
  def __init__(self, loader):
    self.sample_iter = iter(self.batch_sampler)
    ...

  def __next__(self):
    if self.num_workers == 0: # same-process loading
      indices = next(self.sample_iter) # may raise StopIteration
      batch = self.collate_fn([self.dataset[i] for i in indices])
      if self.pin_memory:
        batch = pin_memory_batch(batch)
      return batch
    ...

  def __iter__(self):
    return self

Dataloader类中读取数据Index的方法,采用了 for x in generator 方式,但是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
  """random sampler to yield a mini-batch of indices."""
  def __init__(self, batch_size, dataset, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_imgs = len(dataset)
    self.drop_last = drop_last

  def __iter__(self):
    indices = np.random.permutation(self.num_imgs)
    batch = []
    for i in indices:
      batch.append(i)
      if len(batch) == self.batch_size:
        yield batch
        batch = []
    ## if images not to yield a batch
    if len(batch)>0 and not self.drop_last:
      yield batch


  def __len__(self):
    if self.drop_last:
      return self.num_imgs // self.batch_size
    else:
      return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  • for x in container 可迭代对象
  • for x in iterator 迭代器
  • for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用setup.py安装python包和卸载python包的方法
Nov 27 Python
python使用7z解压apk包的方法
Apr 18 Python
浅析Python的Django框架中的Memcached
Jul 23 Python
Python 支付整合开发包的实现
Jan 23 Python
安装docker-compose的两种最简方法
Jul 30 Python
Python求平面内点到直线距离的实现
Jan 19 Python
Python递归实现打印多重列表代码
Feb 27 Python
python print 格式化输出,动态指定长度的实现
Apr 12 Python
python 实现仿微信聊天时间格式化显示的代码
Apr 17 Python
基于python实现简单网页服务器代码实例
Sep 14 Python
聊聊Python pandas 中loc函数的使用,及跟iloc的区别说明
Mar 03 Python
Pygame Rect区域位置的使用(图文)
Nov 17 Python
django商品分类及商品数据建模实例详解
Jan 03 #Python
PyTorch和Keras计算模型参数的例子
Jan 02 #Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 #Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
You might like
文件上传类
2006/10/09 PHP
用PHP制作静态网站的模板框架(三)
2006/10/09 PHP
利用PHP实现图片等比例放大和缩小的方法详解
2013/06/06 PHP
PHP批量删除、清除UTF-8文件BOM头的代码实例
2014/04/14 PHP
PHP匿名函数和use子句用法实例
2016/03/16 PHP
Yii框架中sphinx索引配置方法解析
2016/10/18 PHP
php提交表单时保留多个空格及换行的文本样式的方法
2017/06/20 PHP
简短几句jquery代码的实现一个图片向上滚动切换
2011/09/02 Javascript
js中的setInterval和setTimeout使用实例
2014/05/09 Javascript
JavaScript重载函数实例剖析
2016/05/13 Javascript
Three.js基础部分学习
2017/01/08 Javascript
Nuxt.js踩坑总结分享
2018/01/18 Javascript
vue中手机号,邮箱正则验证以及60s发送验证码的实例
2018/03/16 Javascript
Vue实现PopupWindow组件详解
2018/04/28 Javascript
js实现录音上传功能
2019/11/22 Javascript
Vue+Bootstrap收藏(点赞)功能逻辑与具体实现
2020/10/22 Javascript
vue 数据操作相关总结
2020/12/17 Vue.js
Python正则表达式匹配HTML页面编码
2015/04/08 Python
python、java等哪一门编程语言适合人工智能?
2017/11/13 Python
TensorFlow实现创建分类器
2018/02/06 Python
python RabbitMQ 使用详细介绍(小结)
2018/11/08 Python
查找适用于matplotlib的中文字体名称与实际文件名对应关系的方法
2021/01/05 Python
css3实例教程 一款纯css3实现的发光屏幕旋转特效
2014/12/07 HTML / CSS
Charlotte Tilbury美国官网:英国美妆品牌
2017/10/13 全球购物
绩效专员岗位职责
2013/12/02 职场文书
银行员工职业规划范文
2014/01/21 职场文书
舞蹈比赛获奖感言
2014/02/04 职场文书
学校联谊活动方案
2014/02/15 职场文书
后勤主管岗位职责
2014/03/01 职场文书
协议书格式
2014/04/23 职场文书
教师工作能力自我评价
2015/03/04 职场文书
2015年助理政工师工作总结
2015/05/26 职场文书
介绍信应该怎么开?
2019/04/03 职场文书
详解Oracle数据库中自带的所有表结构(sql代码)
2021/11/20 Oracle
Python docx库删除复制paragraph及行高设置图片插入示例
2022/07/23 Python
JS前端使用canvas实现扩展物体类和事件派发
2022/08/05 Javascript