使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
Python3读取Excel数据存入MySQL的方法
May 04 Python
pycharm 主题theme设置调整仿sublime的方法
May 23 Python
windows下cx_Freeze生成Python可执行程序的详细步骤
Oct 09 Python
Python后台开发Django的教程详解(启动)
Apr 08 Python
python实现定时压缩指定文件夹发送邮件
Dec 22 Python
pandas实现将日期转换成timestamp
Dec 07 Python
pandas的相关系数与协方差实例
Dec 27 Python
python清空命令行方式
Jan 13 Python
python re模块匹配贪婪和非贪婪模式详解
Feb 11 Python
python可迭代对象去重实例
May 15 Python
python interpolate插值实例
Jul 06 Python
基于Python爬取fofa网页端数据过程解析
Jul 13 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
PHP 内存缓存加速功能memcached安装与用法
2009/09/03 PHP
set_include_path和get_include_path使用及注意事项
2013/02/02 PHP
php定义一个参数带有默认值的函数实例分析
2015/03/16 PHP
php实现给二维数组中所有一维数组添加值的方法
2017/02/04 PHP
TP(thinkPHP)框架多层控制器和多级控制器的使用示例
2018/06/13 PHP
Yii框架模拟组件调用注入示例
2019/11/11 PHP
在JavaScript里嵌入大量字符串常量的实现方法
2013/07/07 Javascript
js判断设备是否为PC并调整图片大小
2014/02/12 Javascript
Javascript实现跑马灯效果的简单实例
2016/05/31 Javascript
jQuery弹出窗口打开链接的实现代码
2016/12/24 Javascript
jQuery简单绑定单个事件的方法示例
2017/06/10 jQuery
vue中使用gojs/jointjs的示例代码
2018/08/24 Javascript
详解JavaScript中关于this指向的4种情况
2019/04/18 Javascript
vue v-for直接循环数字实例
2019/11/07 Javascript
[51:53]完美世界DOTA2联赛循环赛 LBZS vs DM BO2第二场 11.01
2020/11/02 DOTA
详解常用查找数据结构及算法(Python实现)
2016/12/09 Python
python2与python3的print及字符串格式化小结
2018/11/30 Python
Python实现合并两个有序链表的方法示例
2019/01/31 Python
Python中的random.uniform()函数教程与实例解析
2019/03/02 Python
Python向excel中写入数据的方法
2019/05/05 Python
Django项目之Elasticsearch搜索引擎的实例
2019/08/21 Python
TensorFlow Saver:保存和读取模型参数.ckpt实例
2020/02/10 Python
pandas DataFrame 数据选取,修改,切片的实现
2020/04/24 Python
Python实现清理微信僵尸粉功能示例【基于itchat模块】
2020/05/29 Python
北美三大旅游网站之一:Travelocity加拿大
2016/08/20 全球购物
北美个性化礼品商店:Things Remembered
2018/06/12 全球购物
行政管理专业求职信
2014/07/06 职场文书
村长反四风问题个人对照检查材料
2014/09/21 职场文书
机票销售员态度不好检讨书
2014/09/27 职场文书
党的群众路线教育实践活动心得体会(医院)
2014/11/03 职场文书
2014年社区计生工作总结
2014/11/18 职场文书
小学优秀教师材料
2014/12/15 职场文书
色戒观后感
2015/06/12 职场文书
2016年学校爱国卫生月活动总结
2016/04/06 职场文书
导游词之晋城蟒河
2019/12/12 职场文书
读《方与圆》有感:交友方圆有度
2020/01/14 职场文书