pytorch ImageFolder的覆写实例


Posted in Python onFebruary 20, 2020

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的用法介绍

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
 """A generic data loader where the images are arranged in this way: ::

  root/dog/xxx.png
  root/dog/xxy.png
  root/dog/xxz.png

  root/cat/123.png
  root/cat/nsdf3.png
  root/cat/asd932_.png

 Args:
  root (string): Root directory path.
  transform (callable, optional): A function/transform that takes in an PIL image
   and returns a transformed version. E.g, ``transforms.RandomCrop``
  target_transform (callable, optional): A function/transform that takes in the
   target and transforms it.
  loader (callable, optional): A function to load an image given its path.

  Attributes:
  classes (list): List of the class names.
  class_to_idx (dict): Dict with items (class_name, class_index).
  imgs (list): List of (image path, class_index) tuples
 """
 def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
  super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
           transform=transform,
           target_transform=target_transform)
  self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
 """查看文件是否是支持的可扩展类型

 Args:
  filename (string): 文件路径
  extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型

 Returns:
  bool: True if the filename ends with one of given extensions
 """
 filename_lower = filename.lower()
 return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表


def make_dataset(dir, class_to_idx, extensions):
 """
  返回形如[(图像路径, 该图像对应的类别索引值),(),...]
 """
 images = []
 dir = os.path.expanduser(dir)
 for target in sorted(class_to_idx.keys()):
  d = os.path.join(dir, target)
  if not os.path.isdir(d):
   continue

  for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
   for fname in sorted(fnames):
    if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
     path = os.path.join(root, fname)
     item = (path, class_to_idx[target])
     images.append(item)

 return images

class DatasetFolder(data.Dataset):
 """A generic data loader where the samples are arranged in this way: ::

  root/class_x/xxx.ext
  root/class_x/xxy.ext
  root/class_x/xxz.ext

  root/class_y/123.ext
  root/class_y/nsdf3.ext
  root/class_y/asd932_.ext

 Args:
  root (string): 根目录路径
  loader (callable): 根据给定的路径来加载样本的可调用函数
  extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
  transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
   E.g, ``transforms.RandomCrop`` for images.
  target_transform (callable, optional): 用于样本标签的transform函数

  Attributes:
  classes (list): 类别名列表
  class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
  samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
  targets (list): 在数据集中每张图片的类索引值,为列表
 """

 def __init__(self, root, loader, extensions, transform=None, target_transform=None):
  classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
  # 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
  samples = make_dataset(root, class_to_idx, extensions) 
  if len(samples) == 0:
   raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
        "Supported extensions are: " + ",".join(extensions)))

  self.root = root
  self.loader = loader
  self.extensions = extensions

  self.classes = classes
  self.class_to_idx = class_to_idx
  self.samples = samples
  self.targets = [s[1] for s in samples] #所有图像的类索引值组成的列表

  self.transform = transform
  self.target_transform = target_transform

 def _find_classes(self, dir):
  """
  在数据集中查找类文件夹。

  Args:
   dir (string): 根目录路径

  Returns:
   返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.

  Ensures:
   保证没有类名是另一个类目录的子目录
  """
  if sys.version_info >= (3, 5):
   # Faster and available in Python 3.5 and above
   classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
  else:
   classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
  classes.sort() #然后对类名进行排序
  class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}
  return classes, class_to_idx #然后返回类名和类索引

 def __getitem__(self, index):
  """
  Args:
   index (int): Index

  Returns:
   tuple: (sample, target) where target is class_index of the target class.
  """
  path, target = self.samples[index]
  sample = self.loader(path) # 加载图片
  if self.transform is not None:
   sample = self.transform(sample)
  if self.target_transform is not None:
   target = self.target_transform(target)

  return sample, target

 def __len__(self):
  return len(self.samples)

 def __repr__(self):
  fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  fmt_str += ' Root Location: {}\n'.format(self.root)
  tmp = ' Transforms (if any): '
  fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  tmp = ' Target Transforms (if any): '
  fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
 """
  为了得到两张图(其中一张是随机选取的)的图像和索引值信息
 """
 def __init__(self, root, transform=None):
  super(CustomImageFolder, self).__init__(root, transform)
  self.indices = range(len(self)) #该文件夹中的长度

 def __getitem__(self, index1):
  index2 = random.choice(self.indices) #从[0,indices]中随机抽取一个数字,为了随机选取一张图

  path1 = self.imgs[index1][0] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
  label1 = self.imgs[index1][1]
  path2 = self.imgs[index2][0]
  label2 = self.imgs[index2][1]

  img1 = self.loader(path1)
  img2 = self.loader(path2)
  if self.transform is not None:
   img1 = self.transform(img1)
   img2 = self.transform(img2)

  return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件的md5加密方法
Apr 06 Python
Python实现感知器模型、两层神经网络
Dec 19 Python
Python实现抢购IPhone手机
Feb 07 Python
kafka-python批量发送数据的实例
Dec 27 Python
PyQt弹出式对话框的常用方法及标准按钮类型
Feb 27 Python
python将excel转换为csv的代码方法总结
Jul 03 Python
tensorflow 初始化未初始化的变量实例
Feb 06 Python
Python如何使用turtle库绘制图形
Feb 26 Python
简单了解Java Netty Reactor三种线程模型
Apr 26 Python
Python正则表达式如何匹配中文
May 27 Python
完美解决Pycharm中matplotlib画图中文乱码问题
Jan 11 Python
python 第三方库paramiko的常用方式
Feb 20 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
Python os模块常用方法和属性总结
Feb 20 #Python
Python requests获取网页常用方法解析
Feb 20 #Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 #Python
You might like
非常好的php目录导航文件代码
2006/10/09 PHP
PHP设计模式之命令模式的深入解析
2013/06/13 PHP
PHP使用JSON和将json还原成数组
2015/02/12 PHP
PHP导入导出Excel代码
2015/07/07 PHP
多浏览器兼容的获取元素和鼠标的位置的js代码
2009/12/15 Javascript
有关DOM元素与事件的3个谜题
2010/11/11 Javascript
TBCompressor js代码压缩
2011/01/05 Javascript
Js+Flash实现访问剪切板操作
2012/11/20 Javascript
本地图片预览(支持IE6/IE7/IE8/Firefox3)经验总结
2013/03/25 Javascript
jQuery实现二级下拉菜单效果
2016/01/05 Javascript
jQuery Validation Plugin验证插件手动验证
2016/01/26 Javascript
js仿QQ中对联系人向左滑动、滑出删除按钮的操作
2016/04/07 Javascript
前端弹出对话框 js实现ajax交互
2016/09/09 Javascript
基于jQuery实现弹幕APP
2017/02/10 Javascript
angular内置provider之$compileProvider详解
2017/09/27 Javascript
使用JS代码实现俄罗斯方块游戏
2018/08/03 Javascript
Vue-input框checkbox强制刷新问题
2019/04/18 Javascript
vue 进阶之实现父子组件间的传值
2019/04/26 Javascript
JavaScript实现拖拽和缩放效果
2020/08/24 Javascript
vue-axios同时请求多个接口 等所有接口全部加载完成再处理操作
2020/11/09 Javascript
vue-resource 拦截器interceptors使用详解
2021/01/18 Vue.js
PHP webshell检查工具 python实现代码
2009/09/15 Python
详解python中executemany和序列的使用方法
2017/08/12 Python
DataFrame中去除指定列为空的行方法
2018/04/08 Python
Python判断字符串是否xx开始或结尾的示例
2019/08/08 Python
Windows下PyCharm2018.3.2 安装教程(图文详解)
2019/10/24 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
2020/05/25 Python
香港太阳眼镜网上商店:SmartBuyGlasses香港
2016/07/22 全球购物
墨西哥网上购物:Linio墨西哥
2016/10/20 全球购物
师范应届生语文教师求职信
2013/10/29 职场文书
公务员个人自我评价分享
2013/11/06 职场文书
设计大赛策划方案
2014/06/13 职场文书
2014年学生工作总结
2014/11/20 职场文书
2014年保卫科工作总结
2014/12/05 职场文书
企业安全隐患排查治理制度
2015/08/05 职场文书
Win11 Build 21996.1 Dev版怎么样? win11系统截图欣赏
2021/11/21 数码科技