使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
Python中尝试多线程编程的一个简明例子
Apr 07 Python
Python机器学习库scikit-learn安装与基本使用教程
Jun 25 Python
基于腾讯云服务器部署微信小程序后台服务(Python+Django)
May 08 Python
Python 3.8 新功能全解
Jul 25 Python
django多对多表的创建,级联删除及手动创建第三张表
Jul 25 Python
Python使用指定字符长度切分数据示例
Dec 05 Python
Django中从mysql数据库中获取数据传到echarts方式
Apr 07 Python
pyspark给dataframe增加新的一列的实现示例
Apr 24 Python
Python实现一个优先级队列的方法
Jul 31 Python
python map比for循环快在哪
Sep 21 Python
Python 下载Bing壁纸的示例
Sep 29 Python
matplotlib bar()实现多组数据并列柱状图通用简便创建方法
Feb 24 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
用PHP创建PDF中文文档
2006/10/09 PHP
人大复印资料处理程序_查询篇
2006/10/09 PHP
五个PHP程序员工具
2008/05/26 PHP
php递归获取目录内文件(包含子目录)封装类分享
2013/12/25 PHP
php二分查找二种实现示例
2014/03/12 PHP
php实现XSS安全过滤的方法
2015/07/29 PHP
javascript中的startWith和endWith的几种实现方法
2013/05/07 Javascript
jquery实现的美女拼图游戏实例
2015/05/04 Javascript
第五篇Bootstrap 排版
2016/06/21 Javascript
总结Javascript中的隐式类型转换
2016/08/24 Javascript
javascript基本数据类型及类型检测常用方法小结
2016/12/14 Javascript
node.JS md5加密中文与php结果不一致的解决方法
2017/05/05 Javascript
解决Angular4项目部署到服务器上刷新404的问题
2018/08/31 Javascript
微信小程序实现炫酷的弹出式菜单特效
2019/01/28 Javascript
vue中typescript装饰器的使用方法超实用教程
2019/06/17 Javascript
vue flex 布局实现div均分自动换行的示例代码
2020/08/05 Javascript
JavaScript实现网页下拉菜单效果
2020/11/20 Javascript
[02:36]DOTA2混沌骑士 英雄基础教程
2013/11/26 DOTA
Python+Django在windows下的开发环境配置图解
2009/11/11 Python
跟老齐学Python之有容乃大的list(2)
2014/09/15 Python
python在ubuntu中的几种安装方法(小结)
2017/12/08 Python
Python中安装easy_install的方法
2018/11/18 Python
flask框架自定义过滤器示例【markdown文件读取和展示功能】
2019/11/08 Python
如何使用python进行pdf文件分割
2019/11/11 Python
python3-flask-3将信息写入日志的实操方法
2019/11/12 Python
python 实现字符串下标的输出功能
2020/02/13 Python
python生成大写32位uuid代码
2020/03/03 Python
基于Tensorflow一维卷积用法详解
2020/05/22 Python
浅析Python requests 模块
2020/10/09 Python
凯普林包包西班牙官网:Kipling西班牙
2019/04/12 全球购物
八项规定整改方案
2014/02/21 职场文书
2015年电工工作总结
2015/04/10 职场文书
2015年医院护理部工作总结
2015/04/23 职场文书
单位更名证明
2015/06/18 职场文书
2016十一国庆节慰问信
2015/12/01 职场文书
基于Python绘制子图及子图刻度的变换等的问题
2021/05/23 Python