使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
详解django中自定义标签和过滤器
Jul 03 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
Jul 20 Python
Python线性方程组求解运算示例
Jan 17 Python
Tensorflow环境搭建的方法步骤
Feb 07 Python
python生成n个元素的全组合方法
Nov 13 Python
Python shutil模块用法实例分析
Oct 02 Python
selenium WebDriverWait类等待机制的实现
Mar 18 Python
UI自动化定位常用实现方法代码示例
Oct 27 Python
python tkinter实现连连看游戏
Nov 16 Python
Python控制鼠标键盘代码实例
Dec 08 Python
python之django路由和视图案例教程
Jul 26 Python
C3 线性化算法与 MRO之Python中的多继承
Oct 05 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
香妃
2021/03/03 冲泡冲煮
在laravel中实现ORM模型使用第二个数据库设置
2019/10/24 PHP
jquery鼠标停止移动事件
2013/12/21 Javascript
js实现身份证号码验证的简单实例
2014/02/19 Javascript
javascript文件中引用依赖的js文件的方法
2014/03/17 Javascript
javascript使用call调用微信API
2014/12/15 Javascript
JavaScript DOM元素尺寸和位置
2015/04/13 Javascript
浅谈javascript中的闭包
2015/05/13 Javascript
今天抽时间给大家整理jquery和ajax的相关知识
2015/11/17 Javascript
JavaScript通过使用onerror设置默认图像显示代替alt
2016/03/01 Javascript
深入理解JS函数的参数(arguments)的使用
2016/05/28 Javascript
jQuery 实现ajax传入参数含有特殊字符的方法总结
2016/10/17 Javascript
jquery 标签 隔若干行加空白或者加虚线的方法
2016/12/07 Javascript
Vue修改mint-ui默认样式的方法
2018/02/03 Javascript
JavaScript实现计算圆周率到小数点后100位的方法示例
2018/05/08 Javascript
JQuery常见节点操作实例分析
2019/05/15 jQuery
layer扩展打开/关闭动画的方法
2019/09/23 Javascript
extjs图形绘制之饼图实现方法分析
2020/03/06 Javascript
微信小程序 wx:for 与 wx:for-items 与 wx:key的正确用法
2020/05/19 Javascript
[42:32]完美世界DOTA2联赛循环赛 Magma vs PXG BO2第二场 10.28
2020/10/28 DOTA
python中文乱码的解决方法
2013/11/04 Python
Python内置函数Type()函数一个有趣的用法
2015/02/18 Python
解决Python出现_warn_unsafe_extraction问题的方法
2016/03/24 Python
深入解析Python中的__builtins__内建对象
2016/06/21 Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
2018/05/24 Python
Python3.7安装keras和TensorFlow的教程图解
2020/06/18 Python
Python3常用内置方法代码实例
2019/11/18 Python
python 基于opencv 绘制图像轮廓
2020/12/11 Python
求网格中的黑点分布
2013/11/06 面试题
自考毕业生自我鉴定
2013/11/04 职场文书
小学教师岗位职责
2013/11/25 职场文书
《值日生》教学反思
2014/02/17 职场文书
影视后期实训报告
2014/11/05 职场文书
物业保安辞职信
2015/05/12 职场文书
人与自然观后感
2015/06/16 职场文书
《刷子李》教学反思
2016/02/20 职场文书