使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
使用python 获取进程pid号的方法
Mar 10 Python
python正则表达式match和search用法实例
Mar 26 Python
将Django框架和遗留的Web应用集成的方法
Jul 24 Python
Python使用dis模块把Python反编译为字节码的用法详解
Jun 14 Python
python+matplotlib实现动态绘制图片实例代码(交互式绘图)
Jan 20 Python
基于python神经卷积网络的人脸识别
May 24 Python
Python 查找list中的某个元素的所有的下标方法
Jun 27 Python
Python实现多属性排序的方法
Dec 05 Python
对python多线程中互斥锁Threading.Lock的简单应用详解
Jan 11 Python
Python Opencv 通过轨迹(跟踪)栏实现更改整张图像的背景颜色
Mar 09 Python
Python devel安装失败问题解决方案
Jun 09 Python
python drf各类组件的用法和作用
Jan 12 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
centos 5.6 升级php到5.3的方法
2011/05/14 PHP
PHP面向对象之旅:深入理解static变量与方法
2014/01/06 PHP
一个简单的PHP验证码实现代码
2014/05/10 PHP
php实现的click captcha点击验证码类实例
2014/09/23 PHP
PHP 5.3和PHP 5.4出现FastCGI Error解决方法
2015/02/12 PHP
JavaScript 权威指南(第四版) 读书笔记
2009/08/11 Javascript
JS+CSS实现淡入式焦点图片幻灯切换效果的方法
2015/02/26 Javascript
jQuery的css() 方法使用指南
2015/05/03 Javascript
javascript实现对表格元素进行排序操作
2015/11/18 Javascript
jQuery的框架介绍
2016/05/11 Javascript
Node.js中使用jQuery的做法
2016/08/17 Javascript
JavaScript实现DOM对象选择器
2016/09/24 Javascript
详解vue 中使用 AJAX获取数据的方法
2017/01/18 Javascript
jquery拖拽自动排序插件使用方法详解
2020/07/20 jQuery
Angular使用Restful的增删改
2018/12/28 Javascript
vue 移动端注入骨架屏的配置方法
2019/06/25 Javascript
Vue项目中数据的深度监听或对象属性的监听实例
2020/07/17 Javascript
你不知道的SpringBoot与Vue部署解决方案
2020/11/09 Javascript
Python实现的对本地host127.0.0.1主机进行扫描端口功能示例
2019/02/15 Python
python3 xpath和requests应用详解
2020/03/06 Python
Django自关联实现多级联动查询实例
2020/05/19 Python
HTML5调用手机发短信和打电话功能
2020/04/29 HTML / CSS
德国高品质男装及配饰商城:Cultizm(Raw Denim原色牛仔裤)
2018/04/16 全球购物
wedgwood加拿大官网:1759年成立的英国国宝级陶瓷餐具品牌
2018/07/17 全球购物
奢华的意大利皮革手袋:Bene Handbags
2019/10/29 全球购物
中职应届生会计求职信
2013/10/23 职场文书
出纳工作检讨书
2014/10/18 职场文书
财务检查整改报告
2014/11/06 职场文书
2015年助残日活动总结
2015/03/27 职场文书
2015年质量月活动总结报告
2015/03/27 职场文书
初中数学教学随笔
2015/08/15 职场文书
护士岗前培训心得体会
2016/01/08 职场文书
总结Python使用过程中的bug
2021/06/18 Python
go goroutine 怎样进行错误处理
2021/07/16 Golang
分享7个 Python 实战项目练习
2022/03/03 Python
漫画「处刑少女的生存之道」第3卷封面公开
2022/03/21 日漫