使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
python在指定目录下查找gif文件的方法
May 04 Python
python相似模块用例
Mar 04 Python
Python 制作糗事百科爬虫实例
Sep 22 Python
Python Django 实现简单注册功能过程详解
Jul 29 Python
Django项目后台不挂断运行的方法
Aug 31 Python
翻转数列python实现,求前n项和,并能输出整个数列的案例
May 03 Python
python的json包位置及用法总结
Jun 21 Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 Python
PyTorch如何搭建一个简单的网络
Aug 24 Python
Python连接mysql方法及常用参数
Sep 01 Python
python 多线程中join()的作用
Oct 29 Python
Jupyter notebook 输出部分显示不全的解决方案
Apr 24 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
PHP中路径问题的解决方案
2006/10/09 PHP
1.PHP简介
2006/10/09 PHP
为查询结果建立向后/向前按钮
2006/10/09 PHP
php 调用远程url的六种方法小结
2009/11/02 PHP
PHP使用imagick扩展实现合并图像的方法
2017/04/25 PHP
Js callBack 返回前一页的js方法
2008/11/30 Javascript
JavaScript 原型继承
2011/12/26 Javascript
Uglifyjs(JS代码优化工具)入门 安装使用
2020/04/13 Javascript
appendChild() 或 insertBefore()使用与区别介绍
2013/10/11 Javascript
JavaScript中switch判断容易犯错的一个细节
2014/08/27 Javascript
推荐8款jQuery轻量级树形Tree插件
2014/11/12 Javascript
基于RequireJS和JQuery的模块化编程日常问题解析
2016/04/14 Javascript
codeMirror插件使用讲解
2017/01/16 Javascript
详谈jQuery.load()和Jsp的include的区别
2017/04/12 jQuery
Angular.js通过自定义指令directive实现滑块滑动效果
2017/10/13 Javascript
在react-router4中进行代码拆分的方法(基于webpack)
2018/03/08 Javascript
Vue自定义弹窗指令的实现代码
2018/08/13 Javascript
Vue3 源码导读(推荐)
2019/10/14 Javascript
vue使用swiper实现左右滑动切换图片
2020/10/16 Javascript
[06:24]DOTA2 2015国际邀请赛中国区预选赛第二日TOP10
2015/05/27 DOTA
Python牛刀小试密码爆破
2011/02/03 Python
详解Python中的文件操作
2016/08/28 Python
python 专题九 Mysql数据库编程基础知识
2017/03/16 Python
通过celery异步处理一个查询任务的完整代码
2019/11/19 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
2020/10/15 Python
Python通过Schema实现数据验证方式
2020/11/12 Python
python requests库的使用
2021/01/06 Python
高三毕业生自我鉴定
2013/12/20 职场文书
大学生活动策划方案
2014/02/10 职场文书
竞聘书格式及范文
2014/03/31 职场文书
幼儿园师德师风学习材料
2014/05/29 职场文书
技术股东合作协议书
2014/12/02 职场文书
2019年健身俱乐部的创业计划书
2019/08/26 职场文书
关于考试抄袭的检讨书
2019/11/02 职场文书
nginx里的rewrite跳转的实现
2021/03/31 Servers
MySQL中的全表扫描和索引树扫描
2022/05/15 MySQL