使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
python不带重复的全排列代码
Aug 13 Python
python实现udp数据报传输的方法
Sep 26 Python
Python入门篇之对象类型
Oct 17 Python
使用Python从有道词典网页获取单词翻译
Jul 03 Python
PyChar学习教程之自定义文件与代码模板详解
Jul 17 Python
Python复制Word内容并使用格式设字体与大小实例代码
Jan 22 Python
python实现数独游戏 java简单实现数独游戏
Mar 30 Python
使用Python的toolz库开始函数式编程的方法
Nov 15 Python
Python遍历字典方式就实例详解
Dec 28 Python
tensorflow 报错unitialized value的解决方法
Feb 06 Python
Python GUI编程学习笔记之tkinter控件的介绍及基本使用方法详解
Mar 30 Python
分享Python异步爬取知乎热榜
Apr 12 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
MayFish PHP的MVC架构的开发框架
2009/08/13 PHP
PHP实现登陆并抓取微信列表中最新一组微信消息的方法
2017/07/10 PHP
thinkPHP框架中layer.js的封装与使用方法示例
2019/01/18 PHP
PHP7 新增常量
2021/03/09 PHP
js实现两点之间画线的方法
2015/05/12 Javascript
超赞的动手创建JavaScript框架的详细教程
2015/06/30 Javascript
JS中的forEach、$.each、map方法推荐
2016/04/05 Javascript
jQuery通用的全局遍历方法$.each()用法实例
2016/07/04 Javascript
Move.js入门
2017/02/08 Javascript
微信小程序 地图map实例详解
2017/06/07 Javascript
bootstrap 路径导航 分页 进度条的实例代码
2018/08/06 Javascript
浅谈Vue.js 中的 v-on 事件指令的使用
2018/11/25 Javascript
JavaScript中0、空字符串、'0'是true还是false的知识点分享
2019/09/16 Javascript
js实现简单掷骰子效果
2019/10/24 Javascript
如何使用Javascript中的this关键字
2020/05/28 Javascript
[01:16]DOTA2小知识课堂 Ep.03 芒果树无伤肉山
2019/12/05 DOTA
python中使用OpenCV进行人脸检测的例子
2014/04/18 Python
详解Python多线程
2016/11/14 Python
python中Matplotlib实现绘制3D图的示例代码
2017/09/04 Python
python如何重载模块实例解析
2018/01/25 Python
解决Python print输出不换行没空格的问题
2018/11/14 Python
解决win7操作系统Python3.7.1安装后启动提示缺少.dll文件问题
2019/07/15 Python
基于Python+Appium实现京东双十一自动领金币功能
2019/10/31 Python
python将时分秒转换成秒的实例
2019/12/07 Python
解决TensorFlow程序无限制占用GPU的方法
2020/06/30 Python
Eagle Eyes Optics鹰眼光学:高性能太阳镜
2018/12/07 全球购物
我想声明一个指针并为它分配一些空间, 但却不行。这些代码有什么 问题?char *p; *p = malloc(10);
2016/10/06 面试题
九州传奇上机题
2014/07/10 面试题
医学专业本科毕业生自我鉴定
2013/12/28 职场文书
竞聘上岗演讲
2014/05/19 职场文书
人力资源管理专业求职信
2014/07/23 职场文书
餐厅周年庆活动方案
2014/08/25 职场文书
专升本学生毕业自我鉴定
2014/10/04 职场文书
教育实习指导教师评语
2014/12/31 职场文书
三好学生主要事迹材料
2015/11/03 职场文书
2016领导干部廉洁自律心得体会
2016/01/13 职场文书