pytorch ImageFolder的覆写实例


Posted in Python onFebruary 20, 2020

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的用法介绍

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
 """A generic data loader where the images are arranged in this way: ::

  root/dog/xxx.png
  root/dog/xxy.png
  root/dog/xxz.png

  root/cat/123.png
  root/cat/nsdf3.png
  root/cat/asd932_.png

 Args:
  root (string): Root directory path.
  transform (callable, optional): A function/transform that takes in an PIL image
   and returns a transformed version. E.g, ``transforms.RandomCrop``
  target_transform (callable, optional): A function/transform that takes in the
   target and transforms it.
  loader (callable, optional): A function to load an image given its path.

  Attributes:
  classes (list): List of the class names.
  class_to_idx (dict): Dict with items (class_name, class_index).
  imgs (list): List of (image path, class_index) tuples
 """
 def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
  super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
           transform=transform,
           target_transform=target_transform)
  self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
 """查看文件是否是支持的可扩展类型

 Args:
  filename (string): 文件路径
  extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型

 Returns:
  bool: True if the filename ends with one of given extensions
 """
 filename_lower = filename.lower()
 return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表


def make_dataset(dir, class_to_idx, extensions):
 """
  返回形如[(图像路径, 该图像对应的类别索引值),(),...]
 """
 images = []
 dir = os.path.expanduser(dir)
 for target in sorted(class_to_idx.keys()):
  d = os.path.join(dir, target)
  if not os.path.isdir(d):
   continue

  for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
   for fname in sorted(fnames):
    if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
     path = os.path.join(root, fname)
     item = (path, class_to_idx[target])
     images.append(item)

 return images

class DatasetFolder(data.Dataset):
 """A generic data loader where the samples are arranged in this way: ::

  root/class_x/xxx.ext
  root/class_x/xxy.ext
  root/class_x/xxz.ext

  root/class_y/123.ext
  root/class_y/nsdf3.ext
  root/class_y/asd932_.ext

 Args:
  root (string): 根目录路径
  loader (callable): 根据给定的路径来加载样本的可调用函数
  extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
  transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
   E.g, ``transforms.RandomCrop`` for images.
  target_transform (callable, optional): 用于样本标签的transform函数

  Attributes:
  classes (list): 类别名列表
  class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
  samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
  targets (list): 在数据集中每张图片的类索引值,为列表
 """

 def __init__(self, root, loader, extensions, transform=None, target_transform=None):
  classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
  # 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
  samples = make_dataset(root, class_to_idx, extensions) 
  if len(samples) == 0:
   raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
        "Supported extensions are: " + ",".join(extensions)))

  self.root = root
  self.loader = loader
  self.extensions = extensions

  self.classes = classes
  self.class_to_idx = class_to_idx
  self.samples = samples
  self.targets = [s[1] for s in samples] #所有图像的类索引值组成的列表

  self.transform = transform
  self.target_transform = target_transform

 def _find_classes(self, dir):
  """
  在数据集中查找类文件夹。

  Args:
   dir (string): 根目录路径

  Returns:
   返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.

  Ensures:
   保证没有类名是另一个类目录的子目录
  """
  if sys.version_info >= (3, 5):
   # Faster and available in Python 3.5 and above
   classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
  else:
   classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
  classes.sort() #然后对类名进行排序
  class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}
  return classes, class_to_idx #然后返回类名和类索引

 def __getitem__(self, index):
  """
  Args:
   index (int): Index

  Returns:
   tuple: (sample, target) where target is class_index of the target class.
  """
  path, target = self.samples[index]
  sample = self.loader(path) # 加载图片
  if self.transform is not None:
   sample = self.transform(sample)
  if self.target_transform is not None:
   target = self.target_transform(target)

  return sample, target

 def __len__(self):
  return len(self.samples)

 def __repr__(self):
  fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  fmt_str += ' Root Location: {}\n'.format(self.root)
  tmp = ' Transforms (if any): '
  fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  tmp = ' Target Transforms (if any): '
  fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
 """
  为了得到两张图(其中一张是随机选取的)的图像和索引值信息
 """
 def __init__(self, root, transform=None):
  super(CustomImageFolder, self).__init__(root, transform)
  self.indices = range(len(self)) #该文件夹中的长度

 def __getitem__(self, index1):
  index2 = random.choice(self.indices) #从[0,indices]中随机抽取一个数字,为了随机选取一张图

  path1 = self.imgs[index1][0] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
  label1 = self.imgs[index1][1]
  path2 = self.imgs[index2][0]
  label2 = self.imgs[index2][1]

  img1 = self.loader(path1)
  img2 = self.loader(path2)
  if self.transform is not None:
   img1 = self.transform(img1)
   img2 = self.transform(img2)

  return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常见格式化字符串方法小结【百分号与format方法】
Sep 18 Python
python 实现在txt指定行追加文本的方法
Apr 29 Python
关于django 数据库迁移(migrate)应该知道的一些事
May 27 Python
使用Django连接Mysql数据库步骤
Jan 15 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
Python (Win)readline和tab补全的安装方法
Aug 27 Python
PYTHON实现SIGN签名的过程解析
Oct 28 Python
解决os.path.isdir() 判断文件夹却返回false的问题
Nov 29 Python
python标准库OS模块函数列表与实例全解
Mar 10 Python
Python尾递归优化实现代码及原理详解
Oct 09 Python
详解python模块pychartdir安装及导入问题
Oct 22 Python
从np.random.normal()到正态分布的拟合操作
Jun 02 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
Python os模块常用方法和属性总结
Feb 20 #Python
Python requests获取网页常用方法解析
Feb 20 #Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 #Python
You might like
在PHP的图形函数中显示汉字
2006/10/09 PHP
php中的三元运算符使用说明
2011/07/03 PHP
smarty内置函数config_load用法实例
2015/01/22 PHP
php提交过来的数据生成为txt文件
2016/04/28 PHP
php微信开发之带参数二维码的使用
2016/08/03 PHP
PHP入门教程之面向对象基本概念实例分析
2016/09/11 PHP
javascript 面向对象编程基础:继承
2009/08/21 Javascript
Javascript 获取链接(url)参数的方法[正则与截取字符串]
2010/02/09 Javascript
使用Jquery来实现可以输入值的下拉选单 雏型
2011/12/06 Javascript
js控制不同的时间段显示不同的css样式的实例代码
2013/11/04 Javascript
jQuery$命名冲突怎么办如何解决
2014/01/16 Javascript
JavaScript监听和禁用浏览器回车事件实例
2015/01/31 Javascript
基于javascript实现句子翻牌网页版小游戏
2016/03/23 Javascript
jQuery fadeOut 异步实例代码详解
2016/08/18 Javascript
nodejs实现超简单生成二维码的方法
2018/03/17 NodeJs
安装vue-cli的简易过程
2018/05/22 Javascript
a标签调用js的方法总结
2019/09/05 Javascript
JS将指定的某个字符全部转换为其他字符实例代码
2020/10/13 Javascript
解决vue 使用axios.all()方法发起多个请求控制台报错的问题
2020/11/09 Javascript
[00:36]DOTA2上海特级锦标赛 Archon战队宣传片
2016/03/04 DOTA
[39:00]Optic vs VP 2018国际邀请赛淘汰赛BO3 第三场 8.24
2018/08/25 DOTA
django 发送手机验证码的示例代码
2018/04/25 Python
Linux系统(CentOS)下python2.7.10安装
2018/09/26 Python
Python使用paramiko操作linux的方法讲解
2019/02/25 Python
浅析Python与Mongodb数据库之间的操作方法
2019/07/01 Python
python 视频逐帧保存为图片的完整实例
2019/12/10 Python
Python txt文件如何转换成字典
2020/11/03 Python
Python函数调用追踪实现代码
2020/11/27 Python
党校培训自我鉴定范文
2014/03/20 职场文书
学生鉴定评语大全
2014/05/05 职场文书
甲乙双方合作协议书
2014/10/13 职场文书
工会2014法制宣传日活动总结
2014/11/01 职场文书
商场圣诞节活动总结
2015/05/06 职场文书
小学运动会前导词
2015/07/20 职场文书
导游词之河北邯郸
2019/09/12 职场文书
js 实现验证码输入框示例详解
2022/09/23 Javascript