pytorch ImageFolder的覆写实例


Posted in Python onFebruary 20, 2020

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的用法介绍

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
 """A generic data loader where the images are arranged in this way: ::

  root/dog/xxx.png
  root/dog/xxy.png
  root/dog/xxz.png

  root/cat/123.png
  root/cat/nsdf3.png
  root/cat/asd932_.png

 Args:
  root (string): Root directory path.
  transform (callable, optional): A function/transform that takes in an PIL image
   and returns a transformed version. E.g, ``transforms.RandomCrop``
  target_transform (callable, optional): A function/transform that takes in the
   target and transforms it.
  loader (callable, optional): A function to load an image given its path.

  Attributes:
  classes (list): List of the class names.
  class_to_idx (dict): Dict with items (class_name, class_index).
  imgs (list): List of (image path, class_index) tuples
 """
 def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
  super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
           transform=transform,
           target_transform=target_transform)
  self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
 """查看文件是否是支持的可扩展类型

 Args:
  filename (string): 文件路径
  extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型

 Returns:
  bool: True if the filename ends with one of given extensions
 """
 filename_lower = filename.lower()
 return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表


def make_dataset(dir, class_to_idx, extensions):
 """
  返回形如[(图像路径, 该图像对应的类别索引值),(),...]
 """
 images = []
 dir = os.path.expanduser(dir)
 for target in sorted(class_to_idx.keys()):
  d = os.path.join(dir, target)
  if not os.path.isdir(d):
   continue

  for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
   for fname in sorted(fnames):
    if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
     path = os.path.join(root, fname)
     item = (path, class_to_idx[target])
     images.append(item)

 return images

class DatasetFolder(data.Dataset):
 """A generic data loader where the samples are arranged in this way: ::

  root/class_x/xxx.ext
  root/class_x/xxy.ext
  root/class_x/xxz.ext

  root/class_y/123.ext
  root/class_y/nsdf3.ext
  root/class_y/asd932_.ext

 Args:
  root (string): 根目录路径
  loader (callable): 根据给定的路径来加载样本的可调用函数
  extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
  transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
   E.g, ``transforms.RandomCrop`` for images.
  target_transform (callable, optional): 用于样本标签的transform函数

  Attributes:
  classes (list): 类别名列表
  class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
  samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
  targets (list): 在数据集中每张图片的类索引值,为列表
 """

 def __init__(self, root, loader, extensions, transform=None, target_transform=None):
  classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
  # 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
  samples = make_dataset(root, class_to_idx, extensions) 
  if len(samples) == 0:
   raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
        "Supported extensions are: " + ",".join(extensions)))

  self.root = root
  self.loader = loader
  self.extensions = extensions

  self.classes = classes
  self.class_to_idx = class_to_idx
  self.samples = samples
  self.targets = [s[1] for s in samples] #所有图像的类索引值组成的列表

  self.transform = transform
  self.target_transform = target_transform

 def _find_classes(self, dir):
  """
  在数据集中查找类文件夹。

  Args:
   dir (string): 根目录路径

  Returns:
   返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.

  Ensures:
   保证没有类名是另一个类目录的子目录
  """
  if sys.version_info >= (3, 5):
   # Faster and available in Python 3.5 and above
   classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
  else:
   classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
  classes.sort() #然后对类名进行排序
  class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}
  return classes, class_to_idx #然后返回类名和类索引

 def __getitem__(self, index):
  """
  Args:
   index (int): Index

  Returns:
   tuple: (sample, target) where target is class_index of the target class.
  """
  path, target = self.samples[index]
  sample = self.loader(path) # 加载图片
  if self.transform is not None:
   sample = self.transform(sample)
  if self.target_transform is not None:
   target = self.target_transform(target)

  return sample, target

 def __len__(self):
  return len(self.samples)

 def __repr__(self):
  fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  fmt_str += ' Root Location: {}\n'.format(self.root)
  tmp = ' Transforms (if any): '
  fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  tmp = ' Target Transforms (if any): '
  fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
 """
  为了得到两张图(其中一张是随机选取的)的图像和索引值信息
 """
 def __init__(self, root, transform=None):
  super(CustomImageFolder, self).__init__(root, transform)
  self.indices = range(len(self)) #该文件夹中的长度

 def __getitem__(self, index1):
  index2 = random.choice(self.indices) #从[0,indices]中随机抽取一个数字,为了随机选取一张图

  path1 = self.imgs[index1][0] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
  label1 = self.imgs[index1][1]
  path2 = self.imgs[index2][0]
  label2 = self.imgs[index2][1]

  img1 = self.loader(path1)
  img2 = self.loader(path2)
  if self.transform is not None:
   img1 = self.transform(img1)
   img2 = self.transform(img2)

  return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用第三方库xlutils来追加写入Excel文件示例
Apr 05 Python
Python中比较特别的除法运算和幂运算介绍
Apr 05 Python
Python标准库defaultdict模块使用示例
Apr 28 Python
python计算对角线有理函数插值的方法
May 07 Python
Python中%r和%s的详解及区别
Mar 16 Python
OpenCV实现人脸识别
Apr 07 Python
python+opencv识别图片中的圆形
Mar 25 Python
使用Python爬虫库BeautifulSoup遍历文档树并对标签进行操作详解
Jan 25 Python
python脚本和网页有何区别
Jul 02 Python
Python猫眼电影最近上映的电影票房信息
Sep 18 Python
Python+OpenCV图像处理——实现轮廓发现
Oct 23 Python
Django模板报TemplateDoesNotExist异常(亲测可行)
Dec 18 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
Python os模块常用方法和属性总结
Feb 20 #Python
Python requests获取网页常用方法解析
Feb 20 #Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 #Python
You might like
PHP 高手之路(二)
2006/10/09 PHP
PHP 存取 MySQL 数据库的一个例子
2006/10/09 PHP
php数组合并array_merge()函数使用注意事项
2014/06/19 PHP
PHP 计算至少是其他数字两倍的最大数的实现代码
2020/05/26 PHP
下载文件个别浏览器文件名乱码解决办法
2013/03/19 Javascript
js实现浏览器的各种菜单命令比如打印、查看源文件等等
2013/10/24 Javascript
模拟用户点击弹出新页面不会被浏览器拦截
2014/04/08 Javascript
jQuery简单图表peity.js使用示例
2014/05/02 Javascript
如何调试异步加载页面里包含的js文件
2014/10/30 Javascript
SpringMVC返回json数据的三种方式
2015/12/10 Javascript
Node.js程序中的本地文件操作用法小结
2016/03/06 Javascript
JavaScript的Ext JS框架中的GridPanel组件使用指南
2016/05/21 Javascript
js 定位到某个锚点的方法
2016/11/19 Javascript
JavaScript之浏览器对象_动力节点Java学院整理
2017/07/03 Javascript
vue 项目常用加载器及配置详解
2018/01/22 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
详解angular2 控制视图的封装模式
2018/12/27 Javascript
JavaScript实现数字前补“0”的五种方法示例
2019/01/03 Javascript
react 原生实现头像滚动播放的示例
2020/04/21 Javascript
[45:32]Liquid vs LGD 2018国际邀请赛淘汰赛BO3 第二场 8.23
2018/08/24 DOTA
Python Web框架Flask信号机制(signals)介绍
2015/01/01 Python
cmd运行python文件时对结果进行保存的方法
2018/05/16 Python
Python使用指定字符长度切分数据示例
2019/12/05 Python
python 日志模块 日志等级设置失效的解决方案
2020/05/26 Python
python搜索算法原理及实例讲解
2020/11/18 Python
python接口自动化框架实战
2020/12/23 Python
Python3+PyCharm+Django+Django REST framework配置与简单开发教程
2021/02/16 Python
Python中Qslider控件实操详解
2021/02/20 Python
中国综合性网上购物商城:当当(网上卖书起家)
2016/11/16 全球购物
P D PAOLA法国官网:西班牙著名的珠宝首饰品牌
2020/02/15 全球购物
Java中的类包括什么内容?设计时要注意哪些方面
2012/05/23 面试题
质检部岗位职责
2013/11/11 职场文书
母亲80寿诞答谢词
2014/01/16 职场文书
社会工作专业求职信
2014/07/15 职场文书
python元组打包和解包过程详解
2021/08/02 Python
Golang 实现WebSockets
2022/04/24 Golang