Pytorch 数据加载与数据预处理方式


Posted in Python onDecember 31, 2019

数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。

torchvision.datasets中的数据集

torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法。

Pytorch 数据加载与数据预处理方式

Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。

因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写 init ,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。

以CIFAR10为例 源码:

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) ? Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) ? If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) ? A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) ? A function/transform that takes in the target and transforms it.
download (bool, optional) ? If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

加载自己的数据集

对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。

下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么

import torch.utils.data as data
from PIL import Image
import os
import os.path


def has_file_allowed_extension(filename, extensions): //检查输入是否是规定的扩展名
  """Checks if a file is an allowed extension.

  Args:
    filename (string): path to a file

  Returns:
    bool: True if the filename ends with a known image extension
  """
  filename_lower = filename.lower()
  return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
  classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] //获取root目录下所有的文件夹名称

  classes.sort()
  class_to_idx = {classes[i]: i for i in range(len(classes))} //生成类别名称与类别id的对应Dictionary
  return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
  images = []
  dir = os.path.expanduser(dir)// 将~和~user转化为用户目录,对参数中出现~进行处理
  for target in sorted(os.listdir(dir)):
    d = os.path.join(dir, target)
    if not os.path.isdir(d):
      continue

    for root, _, fnames in sorted(os.walk(d)): //os.work包含三个部分,root代表该目录路径 _代表该路径下的文件夹名称集合,fnames代表该路径下的文件名称集合
      for fname in sorted(fnames):
        if has_file_allowed_extension(fname, extensions):
          path = os.path.join(root, fname)
          item = (path, class_to_idx[target])
          images.append(item)  //生成(训练样本图像目录,训练样本所属类别)的元组

  return images  //返回上述元组的列表


class DatasetFolder(data.Dataset):
  """A generic data loader where the samples are arranged in this way: ::

    root/class_x/xxx.ext
    root/class_x/xxy.ext
    root/class_x/xxz.ext

    root/class_y/123.ext
    root/class_y/nsdf3.ext
    root/class_y/asd932_.ext

  Args:
    root (string): Root directory path.
    loader (callable): A function to load a sample given its path.
    extensions (list[string]): A list of allowed extensions.
    transform (callable, optional): A function/transform that takes in
      a sample and returns a transformed version.
      E.g, ``transforms.RandomCrop`` for images.
    target_transform (callable, optional): A function/transform that takes
      in the target and transforms it.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    samples (list): List of (sample path, class_index) tuples
  """

  def __init__(self, root, loader, extensions, transform=None, target_transform=None):
    classes, class_to_idx = find_classes(root)
    samples = make_dataset(root, class_to_idx, extensions)
    if len(samples) == 0:
      raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                "Supported extensions are: " + ",".join(extensions)))

    self.root = root
    self.loader = loader
    self.extensions = extensions

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples

    self.transform = transform
    self.target_transform = target_transform

  def __getitem__(self, index):
    """
    根据index获取sample 返回值为(sample,target)元组,同时如果该类输入参数中有transform和target_transform,torchvision.transforms类型的参数时,将获取的元组分别执行transform和target_transform中的数据转换方法。
       Args:
      index (int): Index

    Returns:
      tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
      sample = self.transform(sample)
    if self.target_transform is not None:
      target = self.target_transform(target)

    return sample, target


  def __len__(self):
    return len(self.samples)

  def __repr__(self): //定义输出对象格式 其中和__str__的区别是__repr__无论是print输出还是直接输出对象自身 都是以定义的格式进行输出,而__str__ 只有在print输出的时候会是以定义的格式进行输出
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '  Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '  Root Location: {}\n'.format(self.root)
    tmp = '  Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '  Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str



IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
  # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  with open(path, 'rb') as f:
    img = Image.open(f)
    return img.convert('RGB')


def accimage_loader(path):
  import accimage
  try:
    return accimage.Image(path)
  except IOError:
    # Potentially a decoding problem, fall back to PIL.Image
    return pil_loader(path)


def default_loader(path):
  from torchvision import get_image_backend
  if get_image_backend() == 'accimage':
    return accimage_loader(path)
  else:
    return pil_loader(path)


class ImageFolder(DatasetFolder): 
  """A generic data loader where the images are arranged in this way: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

  Args:
    root (string): Root directory path.
    transform (callable, optional): A function/transform that takes in an PIL image
      and returns a transformed version. E.g, ``transforms.RandomCrop``
    target_transform (callable, optional): A function/transform that takes in the
      target and transforms it.
    loader (callable, optional): A function to load an image given its path.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples
  """
  def __init__(self, root, transform=None, target_transform=None,
         loader=default_loader):
    super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                     transform=transform,
                     target_transform=target_transform)
    self.imgs = self.samples

如果自己所要加载的数据组织形式如下

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

即不同类别的训练数据分别存储在不同的文件夹中,这些文件夹都在root(即形如 D:/animals 或者 /usr/animals )路径下

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

参数如下:

root (string) ? Root directory path.
transform (callable, optional) ? A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) ? A function/transform that takes in the target and transforms it.
loader ? A function to load an image given its path. 就是上述源码中


__getitem__(index)
Parameters: index (int) ? Index
Returns:  (sample, target) where target is class_index of the target class.
Return type:  tuple

可以通过torchvision.datasets.ImageFolder进行加载

img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                      transform=transforms.Compose([
                        transforms.Scale(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor()])
                      )
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))

对于所有的训练样本都在一个文件夹中 同时有一个对应的txt文件每一行分别是对应图像的路径以及其所属的类别,可以参照上述class写出对应的加载类

def default_loader(path):
  return Image.open(path).convert('RGB')


class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))

DataLoader解析

位于torch.util.data.DataLoader中 源代码

该接口的主要目的是将pytorch中已有的数据接口如torchvision.datasets.ImageFolder,或者自定义的数据读取接口转化按照

batch_size的大小封装为Tensor,即相当于在内置数据接口或者自定义数据接口的基础上增加一维,大小为batch_size的大小,

得到的数据在之后可以通过封装为Variable,作为模型的输出

_ _ init _ _中所需的参数如下

1. dataset torch.utils.data.Dataset类的子类,可以是torchvision.datasets.ImageFolder等内置类,也可是继承了torch.utils.data.Dataset的自定义类
2. batch_size 每一个batch中包含的样本个数,默认是1 
3. shuffle 一般在训练集中采用,默认是false,设置为true则每一个epoch都会将训练样本打乱
4. sampler 训练样本选取策略,和shuffle是互斥的 如果 shuffle为true,该参数一定要为None
5. batch_sampler BatchSampler 一次产生一个 batch 的 indices,和sampler以及shuffle互斥,一般使用默认的即可
  上述Sampler的源代码地址如下[源代码](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py)
6. num_workers 用于数据加载的线程数量 默认为0 即只有主线程用来加载数据
7. collate_fn 用来聚合数据生成mini_batch

使用的时候一般为如下使用方法:

train_data=torch.utils.data.DataLoader(...) 
for i, (input, target) in enumerate(train_data): 
...

循环取DataLoader中的数据会触发类中_ _ iter __方法,查看源代码可知 其中调用的方法为 return _DataLoaderIter(self),因此需要查看 DataLoaderIter 这一内部类

class DataLoaderIter(object):
  "Iterates once over the DataLoader's dataset, as specified by the sampler"

  def __init__(self, loader):
    self.dataset = loader.dataset
    self.collate_fn = loader.collate_fn
    self.batch_sampler = loader.batch_sampler
    self.num_workers = loader.num_workers
    self.pin_memory = loader.pin_memory and torch.cuda.is_available()
    self.timeout = loader.timeout
    self.done_event = threading.Event()

    self.sample_iter = iter(self.batch_sampler)

    if self.num_workers > 0:
      self.worker_init_fn = loader.worker_init_fn
      self.index_queue = multiprocessing.SimpleQueue()
      self.worker_result_queue = multiprocessing.SimpleQueue()
      self.batches_outstanding = 0
      self.worker_pids_set = False
      self.shutdown = False
      self.send_idx = 0
      self.rcvd_idx = 0
      self.reorder_dict = {}

      base_seed = torch.LongTensor(1).random_()[0]
      self.workers = [
        multiprocessing.Process(
          target=_worker_loop,
          args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
             base_seed + i, self.worker_init_fn, i))
        for i in range(self.num_workers)]

      if self.pin_memory or self.timeout > 0:
        self.data_queue = queue.Queue()
        self.worker_manager_thread = threading.Thread(
          target=_worker_manager_loop,
          args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
             torch.cuda.current_device()))
        self.worker_manager_thread.daemon = True
        self.worker_manager_thread.start()
      else:
        self.data_queue = self.worker_result_queue

      for w in self.workers:
        w.daemon = True # ensure that the worker exits on process exit
        w.start()

      _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
      _set_SIGCHLD_handler()
      self.worker_pids_set = True

      # prime the prefetch loop
      for _ in range(2 * self.num_workers):
        self._put_indices()

以上这篇Pytorch 数据加载与数据预处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中传递参数到URLconf的视图函数中的方法
Jul 18 Python
教大家玩转Python字符串处理的七种技巧
Mar 31 Python
vue.js实现输入框输入值内容实时响应变化示例
Jul 07 Python
pandas DataFrame 删除重复的行的实现方法
Jan 29 Python
python3使用print打印带颜色的字符串代码实例
Aug 22 Python
Python动态导入模块和反射机制详解
Feb 18 Python
解决Jupyter无法导入已安装的 module问题
Apr 17 Python
python 实现读取csv数据,分类求和 再写进 csv
May 18 Python
windows10在visual studio2019下配置使用openCV4.3.0
Jul 14 Python
python常见的占位符总结及用法
Jul 02 Python
Python人工智能之混合高斯模型运动目标检测详解分析
Nov 07 Python
如何使用python包中的sched事件调度器
Apr 30 Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
Python面向对象封装操作案例详解
Dec 31 #Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
Dec 31 #Python
You might like
DC的38部超级英雄动画电影
2020/03/03 欧美动漫
php实现模拟登陆方正教务系统抓取课表
2015/05/19 PHP
列表内容的选择
2006/06/30 Javascript
javascript 处理HTML元素必须避免使用的一种方法
2009/07/30 Javascript
JavaScript中的私有/静态属性介绍
2012/07/26 Javascript
JS动态添加option和删除option(附实例代码)
2013/04/01 Javascript
javascript自定义startWith()和endWith()的两种方法
2013/11/11 Javascript
node.js中的fs.fchmod方法使用说明
2014/12/16 Javascript
后台获取ZTREE选中节点的方法
2015/02/12 Javascript
JavaScript中的原型prototype属性使用详解
2015/06/05 Javascript
jquery实现的缩略图预览滑块实例
2015/06/25 Javascript
关于angularJs指令的Scope(作用域)介绍
2016/10/25 Javascript
微信小程序实战之运维小项目
2017/01/17 Javascript
基于vue.js轮播组件vue-awesome-swiper实现轮播图
2017/03/17 Javascript
vue组件三大核心概念图文详解
2019/05/30 Javascript
解决layui的radio属性或别的属性没显示出来的问题
2019/09/26 Javascript
JavaScript交换两个变量方法实例
2019/11/25 Javascript
5个你不知道的JavaScript字符串处理库(小结)
2020/06/01 Javascript
Angular利用HTTP POST下载流文件的步骤记录
2020/07/26 Javascript
vue 修改 data 数据问题并实时显示操作
2020/09/07 Javascript
PyCharm鼠标右键不显示Run unittest的解决方法
2018/11/30 Python
一篇文章搞懂Python的类与对象名称空间
2018/12/10 Python
Django如何防止定时任务并发浅析
2019/05/14 Python
Python中调用其他程序的方式详解
2019/08/06 Python
TensorFlow基本的常量、变量和运算操作详解
2020/02/03 Python
一些关于python 装饰器的个人理解
2020/08/31 Python
python 基于Apscheduler实现定时任务
2020/12/15 Python
结合CSS3的布局新特征谈谈常见布局方法
2016/01/22 HTML / CSS
html5 canvas fillRect坐标和大小的问题解决方法
2014/03/26 HTML / CSS
阿联酋航空假期:Emirates Holidays
2018/03/20 全球购物
js实现弹框效果
2021/03/24 Javascript
干部下基层实施方案
2014/03/14 职场文书
清洁工工作总结
2015/08/11 职场文书
2016年暑期见闻作文
2015/11/25 职场文书
小学大队委竞选口号
2015/12/25 职场文书
2019年共青团工作条例最新版
2019/11/12 职场文书