使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
以Flask为例讲解Python的框架的使用方法
Apr 29 Python
bpython 功能强大的Python shell
Feb 16 Python
Python脚本实现自动发带图的微博
Apr 27 Python
python 简单的绘图工具turtle使用详解
Jun 21 Python
解决python给列表里添加字典时被最后一个覆盖的问题
Jan 21 Python
django 数据库连接模块解析及简单长连接改造方法
Aug 29 Python
python socket通信编程实现文件上传代码实例
Dec 14 Python
Python list运算操作代码实例解析
Jan 20 Python
Jupyter notebook如何修改平台字体
May 13 Python
python中数字是否为可变类型
Jul 08 Python
python爬不同图片分别保存在不同文件夹中的实现
Apr 02 Python
Python 中 Shutil 模块详情
Nov 11 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
全新的PDO数据库操作类php版(仅适用Mysql)
2012/07/22 PHP
PHP7常量数组用法分析
2016/09/26 PHP
Yii2框架可逆加密简单实现方法
2017/08/25 PHP
PHP实现转盘抽奖算法分享
2020/04/15 PHP
event对象的方法 兼容多浏览器
2009/06/27 Javascript
javascript 操作符(~、&、|、^、)使用案例
2014/12/31 Javascript
JavaScript在浏览器标题栏上显示当前日期和时间的方法
2015/03/19 Javascript
AngularJS基础学习笔记之表达式
2015/05/10 Javascript
jsonp跨域请求数据实现手机号码查询实例分析
2015/12/12 Javascript
js实现页面跳转的五种方法推荐
2016/03/10 Javascript
jQuery Validation Engine验证控件调用外部函数验证的方法
2017/01/18 Javascript
JavaScript实现两个select下拉框选项左移右移
2017/03/09 Javascript
JS排序之冒泡排序详解
2017/04/08 Javascript
基于vue2的canvas时钟倒计时组件步骤解析
2018/11/05 Javascript
JavaScript常见继承模式实例小结
2019/01/11 Javascript
微信小程序页面间跳转传参方式总结
2019/06/13 Javascript
Vue中图片Src使用变量的方法
2019/10/30 Javascript
Vuex的热更替如何实现
2020/06/05 Javascript
vue 实现超长文本截取,悬浮框提示
2020/07/29 Javascript
Vue+element+cookie记住密码功能的简单实现方法
2020/09/20 Javascript
Python的Flask框架中Flask-Admin库的简单入门指引
2015/04/07 Python
Python实现二分查找与bisect模块详解
2017/01/13 Python
python修改list中所有元素类型的三种方法
2018/04/09 Python
使用Python横向合并excel文件的实例
2018/12/11 Python
python多进程下实现日志记录按时间分割
2019/07/22 Python
TensorFlow梯度求解tf.gradients实例
2020/02/04 Python
Python -m参数原理及使用方法解析
2020/08/21 Python
python中zip()函数遍历多个列表方法
2021/02/18 Python
HTML5不支持标签和新增标签详解
2016/06/27 HTML / CSS
吃透移动端 Html5 响应式布局
2019/12/16 HTML / CSS
奥林匹亚体育:Olympia Sports
2020/12/30 全球购物
Claire’s法国:时尚配饰、美容、珠宝、头发
2021/01/16 全球购物
2014年社区个人工作总结
2014/12/02 职场文书
商场圣诞节活动总结
2015/05/06 职场文书
银行岗位培训心得体会
2016/01/09 职场文书
HTML+CSS制作心跳特效的实现
2021/05/26 HTML / CSS