pytorch ImageFolder的覆写实例


Posted in Python onFebruary 20, 2020

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的用法介绍

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
 """A generic data loader where the images are arranged in this way: ::

  root/dog/xxx.png
  root/dog/xxy.png
  root/dog/xxz.png

  root/cat/123.png
  root/cat/nsdf3.png
  root/cat/asd932_.png

 Args:
  root (string): Root directory path.
  transform (callable, optional): A function/transform that takes in an PIL image
   and returns a transformed version. E.g, ``transforms.RandomCrop``
  target_transform (callable, optional): A function/transform that takes in the
   target and transforms it.
  loader (callable, optional): A function to load an image given its path.

  Attributes:
  classes (list): List of the class names.
  class_to_idx (dict): Dict with items (class_name, class_index).
  imgs (list): List of (image path, class_index) tuples
 """
 def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
  super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
           transform=transform,
           target_transform=target_transform)
  self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
 """查看文件是否是支持的可扩展类型

 Args:
  filename (string): 文件路径
  extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型

 Returns:
  bool: True if the filename ends with one of given extensions
 """
 filename_lower = filename.lower()
 return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表


def make_dataset(dir, class_to_idx, extensions):
 """
  返回形如[(图像路径, 该图像对应的类别索引值),(),...]
 """
 images = []
 dir = os.path.expanduser(dir)
 for target in sorted(class_to_idx.keys()):
  d = os.path.join(dir, target)
  if not os.path.isdir(d):
   continue

  for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
   for fname in sorted(fnames):
    if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
     path = os.path.join(root, fname)
     item = (path, class_to_idx[target])
     images.append(item)

 return images

class DatasetFolder(data.Dataset):
 """A generic data loader where the samples are arranged in this way: ::

  root/class_x/xxx.ext
  root/class_x/xxy.ext
  root/class_x/xxz.ext

  root/class_y/123.ext
  root/class_y/nsdf3.ext
  root/class_y/asd932_.ext

 Args:
  root (string): 根目录路径
  loader (callable): 根据给定的路径来加载样本的可调用函数
  extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
  transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
   E.g, ``transforms.RandomCrop`` for images.
  target_transform (callable, optional): 用于样本标签的transform函数

  Attributes:
  classes (list): 类别名列表
  class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
  samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
  targets (list): 在数据集中每张图片的类索引值,为列表
 """

 def __init__(self, root, loader, extensions, transform=None, target_transform=None):
  classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
  # 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
  samples = make_dataset(root, class_to_idx, extensions) 
  if len(samples) == 0:
   raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
        "Supported extensions are: " + ",".join(extensions)))

  self.root = root
  self.loader = loader
  self.extensions = extensions

  self.classes = classes
  self.class_to_idx = class_to_idx
  self.samples = samples
  self.targets = [s[1] for s in samples] #所有图像的类索引值组成的列表

  self.transform = transform
  self.target_transform = target_transform

 def _find_classes(self, dir):
  """
  在数据集中查找类文件夹。

  Args:
   dir (string): 根目录路径

  Returns:
   返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.

  Ensures:
   保证没有类名是另一个类目录的子目录
  """
  if sys.version_info >= (3, 5):
   # Faster and available in Python 3.5 and above
   classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
  else:
   classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
  classes.sort() #然后对类名进行排序
  class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}
  return classes, class_to_idx #然后返回类名和类索引

 def __getitem__(self, index):
  """
  Args:
   index (int): Index

  Returns:
   tuple: (sample, target) where target is class_index of the target class.
  """
  path, target = self.samples[index]
  sample = self.loader(path) # 加载图片
  if self.transform is not None:
   sample = self.transform(sample)
  if self.target_transform is not None:
   target = self.target_transform(target)

  return sample, target

 def __len__(self):
  return len(self.samples)

 def __repr__(self):
  fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  fmt_str += ' Root Location: {}\n'.format(self.root)
  tmp = ' Transforms (if any): '
  fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  tmp = ' Target Transforms (if any): '
  fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
 """
  为了得到两张图(其中一张是随机选取的)的图像和索引值信息
 """
 def __init__(self, root, transform=None):
  super(CustomImageFolder, self).__init__(root, transform)
  self.indices = range(len(self)) #该文件夹中的长度

 def __getitem__(self, index1):
  index2 = random.choice(self.indices) #从[0,indices]中随机抽取一个数字,为了随机选取一张图

  path1 = self.imgs[index1][0] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
  label1 = self.imgs[index1][1]
  path2 = self.imgs[index2][0]
  label2 = self.imgs[index2][1]

  img1 = self.loader(path1)
  img2 = self.loader(path2)
  if self.transform is not None:
   img1 = self.transform(img1)
   img2 = self.transform(img2)

  return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
bpython 功能强大的Python shell
Feb 16 Python
python异常和文件处理机制详解
Jul 19 Python
基于Python实现对PDF文件的OCR识别
Aug 05 Python
python基础教程之五种数据类型详解
Jan 12 Python
python 类详解及简单实例
Mar 24 Python
Python实现购物车功能的方法分析
Nov 10 Python
python实时获取外部程序输出结果的方法
Jan 12 Python
python 实现将多条曲线画在一幅图上的方法
Jul 07 Python
详解python uiautomator2 watcher的使用方法
Sep 09 Python
分享一个pycharm专业版安装的永久使用方法
Sep 24 Python
利用python绘制正态分布曲线
Jan 04 Python
python实现监听键盘
Apr 26 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
Python os模块常用方法和属性总结
Feb 20 #Python
Python requests获取网页常用方法解析
Feb 20 #Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 #Python
You might like
mysql建立外键
2006/11/25 PHP
简单分析ucenter 会员同步登录通信原理
2014/08/25 PHP
PHP使用DOM对XML解析处理操作示例
2019/07/04 PHP
Laravel框架源码解析之入口文件原理分析
2020/05/14 PHP
基于逻辑运算的简单权限系统(实现) JS 版
2007/03/24 Javascript
jquery 图片预加载 自动等比例缩放插件
2008/12/25 Javascript
CutePsWheel javascript libary 控制输入文本框为可使用滚轮控制的js库
2010/02/07 Javascript
日期处理的js库(迷你版)--自建js库总结
2011/11/21 Javascript
Javascript跨域请求的4种解决方式
2013/03/17 Javascript
javascript中强制执行toString()具体实现
2013/04/27 Javascript
解决js数据包含加号+通过ajax传到后台时出现连接错误
2013/08/01 Javascript
兼容主流浏览器的iframe自适应高度js脚本
2014/01/10 Javascript
jQuery实现级联下拉框实战(5)
2017/02/08 Javascript
Angular开发者指南之入门介绍
2017/03/05 Javascript
vue组件中使用props传递数据的实例详解
2018/04/08 Javascript
vue.js动画中的js钩子函数的实现
2018/07/06 Javascript
详解Eslint 配置及规则说明
2018/09/10 Javascript
JavaScript实现alert弹框效果
2020/11/19 Javascript
python使用append合并两个数组的方法
2015/04/28 Python
python with提前退出遇到的坑与解决方案
2018/01/05 Python
python pandas 如何替换某列的一个值
2018/06/09 Python
python实现自动登录后台管理系统
2018/10/18 Python
浅析python参数的知识点
2018/12/10 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
2019/07/06 Python
python中删除某个元素的方法解析
2019/11/05 Python
浅谈tensorflow中张量的提取值和赋值
2020/01/19 Python
利用matplotlib为图片上添加触发事件进行交互
2020/04/23 Python
python中setuptools的作用是什么
2020/06/19 Python
python读取xml文件方法解析
2020/08/04 Python
世界上最受欢迎的钓鱼诱饵:Rapala
2019/05/02 全球购物
潘多拉珠宝俄罗斯官方网上商店:PANDORA俄罗斯
2020/09/22 全球购物
Columbia Sportswear法国官网:全球户外品牌
2020/09/25 全球购物
2015年元旦演讲稿
2014/09/12 职场文书
2014年幼儿园园务工作总结
2014/12/05 职场文书
队列队形口号
2015/12/25 职场文书
Python基础之函数嵌套知识总结
2021/05/23 Python