pytorch::Dataloader中的迭代器和生成器应用详解


Posted in Python onJanuary 03, 2020

在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。

为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。

这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。

本文的内容主要有:

  • 解释python中的迭代器和生成器概念
  • 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有以下概念:

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。

学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。

因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?

迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。

当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:

  • 把大的数据分成几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:

  • for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,可以从0开始按顺序遍历序列容器中的元素。
  • for x in iterator: 为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  • for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
  print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
  print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
  print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:

  • __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
  • __iter__可以返回自身(return self),实际读取数据的实现放在__next__方法
  • __iter__可以和yield搭配,返回生成器对象

__iter__返回自身的做法有点类似 python中的类型系统。为了保持一致性,python中一切皆对象。

每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。

生成器,是在__iter__方法中加入yield语句,好处有:

  • 减少循环判断逻辑的复杂度
  • 惰性取值,节省内存和时间

yield作用:

  • 代替函数中的return语句
  • 记住上一次循环迭代器内部元素的位置

三种循环模式常用函数

for x in container 方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator 方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator 方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用 for x in iterator 模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

以下代码只截取了单线程下的数据读取。

class DataLoader(object):
  r"""
  Data loader. Combines a dataset and a sampler, and provides
  single- or multi-process iterators over the dataset.
  """
  def __init__(self, dataset, batch_size=1, shuffle=False, ...):
    self.dataset = dataset
    self.batch_sampler = batch_sampler
    ...
  
  def __iter__(self):
    return _DataLoaderIter(self)

  def __len__(self):
    return len(self.batch_sampler)

class _DataLoaderIter(object):
  r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
  def __init__(self, loader):
    self.sample_iter = iter(self.batch_sampler)
    ...

  def __next__(self):
    if self.num_workers == 0: # same-process loading
      indices = next(self.sample_iter) # may raise StopIteration
      batch = self.collate_fn([self.dataset[i] for i in indices])
      if self.pin_memory:
        batch = pin_memory_batch(batch)
      return batch
    ...

  def __iter__(self):
    return self

Dataloader类中读取数据Index的方法,采用了 for x in generator 方式,但是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
  """random sampler to yield a mini-batch of indices."""
  def __init__(self, batch_size, dataset, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_imgs = len(dataset)
    self.drop_last = drop_last

  def __iter__(self):
    indices = np.random.permutation(self.num_imgs)
    batch = []
    for i in indices:
      batch.append(i)
      if len(batch) == self.batch_size:
        yield batch
        batch = []
    ## if images not to yield a batch
    if len(batch)>0 and not self.drop_last:
      yield batch


  def __len__(self):
    if self.drop_last:
      return self.num_imgs // self.batch_size
    else:
      return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  • for x in container 可迭代对象
  • for x in iterator 迭代器
  • for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的条件判断语句基础学习教程
Feb 07 Python
用Python将IP地址在整型和字符串之间轻松转换
Mar 22 Python
flask + pymysql操作Mysql数据库的实例
Nov 13 Python
python书籍信息爬虫实例
Mar 19 Python
TensorFlow入门使用 tf.train.Saver()保存模型
Apr 24 Python
Python数据分析matplotlib设置多个子图的间距方法
Aug 03 Python
numpy.random模块用法总结
May 27 Python
Python批量修改图片分辨率的实例代码
Jul 04 Python
python的列表List求均值和中位数实例
Mar 03 Python
Python 多进程、多线程效率对比
Nov 19 Python
如何用python绘制雷达图
Apr 24 Python
python 实现的截屏工具
May 08 Python
django商品分类及商品数据建模实例详解
Jan 03 #Python
PyTorch和Keras计算模型参数的例子
Jan 02 #Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 #Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
You might like
PHP抓屏函数实现屏幕快照代码分享
2014/01/02 PHP
成为好程序员必须避免的5个坏习惯
2014/07/04 PHP
php实现可逆加密的方法
2015/08/11 PHP
jQuery+Ajax+PHP“喜欢”评级功能实现代码
2015/10/08 PHP
Linux安装配置php环境的方法
2016/01/14 PHP
基于PHP实现用户在线状态检测
2020/11/10 PHP
js String对象中常用方法小结(字符串操作)
2012/01/27 Javascript
Jquery 数据选择插件Pickerbox使用介绍
2012/08/24 Javascript
jquery可见性过滤选择器使用示例
2013/06/24 Javascript
jquery 获取dom固定元素 添加样式的简单实例
2014/02/04 Javascript
javascript消除window.close()的提示窗口
2015/05/20 Javascript
轻松学习jQuery插件EasyUI EasyUI实现拖放商品放置购物车
2015/11/30 Javascript
理解JavaScript表单的基础知识
2016/01/25 Javascript
ionic隐藏tabs的方法
2016/08/29 Javascript
基于jQuery实现顶部导航栏功能
2016/12/27 Javascript
uploader秒传图片到服务器完整代码
2017/04/22 Javascript
javascript function(函数类型)使用与注意事项小结
2019/06/10 Javascript
layui.tree组件的使用以及搜索节点功能的实现
2019/09/26 Javascript
Python实现将n个点均匀地分布在球面上的方法
2015/03/12 Python
Python实现文件复制删除
2016/04/19 Python
详谈python3中用for循环删除列表中元素的坑
2018/04/19 Python
深入理解Django自定义信号(signals)
2018/10/15 Python
使用Python编写Prometheus监控的方法
2018/10/15 Python
Python向excel中写入数据的方法
2019/05/05 Python
python如何以表格形式打印输出的方法示例
2019/06/21 Python
Python爬虫运用正则表达式的方法和优缺点
2019/08/25 Python
Python实现简单的2048小游戏
2021/03/01 Python
日本著名的平价时尚女性购物网站:Fifth
2016/08/24 全球购物
澳大利亚当地最大的时装生产商:Cue
2018/08/06 全球购物
Rodd & Gunn澳大利亚官网:新西兰男装品牌
2018/09/25 全球购物
房展策划方案
2014/06/07 职场文书
2014年节能减排工作总结
2014/12/06 职场文书
幼儿园庆六一主持词
2015/06/30 职场文书
联村联户简报
2015/07/21 职场文书
SpringBoot集成Redis,并自定义对象序列化操作
2021/06/22 Java/Android
Python的代理类实现,控制访问和修改属性的权限你都了解吗
2022/03/21 Python