使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
使用PYTHON创建XML文档
Mar 01 Python
使用python实现生成用户信息
Mar 20 Python
浅谈python函数之作用域(python3.5)
Oct 27 Python
利用Python进行异常值分析实例代码
Dec 07 Python
Python+树莓派+YOLO打造一款人工智能照相机
Jan 02 Python
Python实现的KMeans聚类算法实例分析
Dec 29 Python
Python获取Redis所有Key以及内容的方法
Feb 19 Python
详解Python 定时框架 Apscheduler原理及安装过程
Jun 14 Python
PyQt5 对图片进行缩放的实例
Jun 18 Python
python 读写文件包含多种编码格式的解决方式
Dec 20 Python
python实现简单坦克大战
Mar 27 Python
如何将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
Apr 22 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
PHP实现采集抓取淘宝网单个商品信息
2015/01/08 PHP
javascript网页关闭时提醒效果脚本
2008/10/22 Javascript
不同浏览器对回车提交表单的处理办法
2010/02/13 Javascript
windows下安装nodejs及框架express
2015/08/07 NodeJs
jquery获取url参数及url加参数的方法
2015/10/26 Javascript
Bootstrap每天必学之导航组件
2016/04/25 Javascript
vue父子组件的数据传递示例
2017/03/07 Javascript
微信小程序 数据遍历的实现
2017/04/05 Javascript
iscroll.js滚动加载实例详解
2017/07/18 Javascript
javascript中的隐式调用
2018/02/10 Javascript
使用Vuex解决Vue中的身份验证问题
2018/09/28 Javascript
Vue+element 解决浏览器自动填充记住的账号密码问题
2019/06/11 Javascript
[57:18]DOTA2上海特级锦标赛主赛事日 - 1 败者组第一轮#3VP VS VG
2016/03/03 DOTA
[58:32]EG vs Liquid 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
Python基于PycURL实现POST的方法
2015/07/25 Python
Python读取properties配置文件操作示例
2018/03/29 Python
Tensorflow 查看变量的值方法
2018/06/14 Python
PyCharm在新窗口打开项目的方法
2019/01/17 Python
使用python快速实现不同机器间文件夹共享方式
2019/12/22 Python
Linux下升级安装python3.8并配置pip及yum的教程
2020/01/02 Python
Python如何使用ConfigParser读取配置文件
2020/11/12 Python
css3中less实现文字长阴影(long shadow)
2020/04/24 HTML / CSS
美国知名平价彩妆品牌:e.l.f. Cosmetics
2017/11/20 全球购物
师范院校学生自荐信范文
2013/12/27 职场文书
户外用品商店创业计划书
2014/01/29 职场文书
美术教师岗位职责
2014/03/18 职场文书
迎国庆演讲稿
2014/09/15 职场文书
先进基层党组织材料
2014/12/25 职场文书
成绩单家长意见
2015/06/03 职场文书
安全第一课观后感
2015/06/18 职场文书
无婚姻登记记录证明
2015/06/18 职场文书
2016年综治宣传月活动宣传标语口号
2016/03/16 职场文书
让人瞬间清醒的句子,句句经典,字字如金
2019/07/08 职场文书
范文之农村基层党建工作报告
2019/10/24 职场文书
Python中os模块的简单使用及重命名操作
2021/04/17 Python
Apache Kafka 分区重分配的实现原理解析
2022/07/15 Servers