使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
Python的词法分析与语法分析
May 18 Python
Python内置函数Type()函数一个有趣的用法
Feb 18 Python
在RedHat系Linux上部署Python的Celery框架的教程
Apr 07 Python
使用Python的Twisted框架编写简单的网络客户端
Apr 16 Python
Python编程修改MP3文件名称的方法
Apr 19 Python
Python实现的快速排序算法详解
Aug 01 Python
浅谈Python traceback的优雅处理
Aug 31 Python
python实现列表的排序方法分享
Jul 01 Python
python中matplotlib实现随鼠标滑动自动标注代码
Apr 23 Python
python爬虫使用正则爬取网站的实现
Aug 03 Python
用python对excel查重
Dec 07 Python
Django对接elasticsearch实现全文检索的示例代码
Aug 02 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
深入PHP运行环境配置的详解
2013/06/04 PHP
解析mysql 表中的碎片产生原因以及清理
2013/06/22 PHP
一致性哈希算法以及其PHP实现详细解析
2013/08/24 PHP
PHP判断数据库中的记录是否存在的方法
2014/11/14 PHP
php使用CURL伪造IP和来源实例详解
2015/01/15 PHP
PHP导入导出Excel代码
2015/07/07 PHP
PHP数据库操作四:mongodb用法分析
2017/08/16 PHP
Laravel 已登陆用户再次查看登陆页面的自动跳转设置方法
2019/09/30 PHP
jQuery 表单验证扩展(三)
2010/10/20 Javascript
JS操作CSS随机改变网页背景实现思路
2014/03/10 Javascript
jquery解析xml字符串简单示例
2014/04/11 Javascript
JS使用for循环遍历Table的所有单元格内容
2014/08/21 Javascript
JavaScript indexOf方法入门实例(计算指定字符在字符串中首次出现的位置)
2014/10/17 Javascript
详解JS面向对象编程
2016/01/24 Javascript
Bootstrap树形组件jqTree的简单封装
2016/01/25 Javascript
Bootstrap每天必学之日期控制
2016/03/07 Javascript
jQuery 实现评论等级好评差评特效
2016/05/06 Javascript
Javascript类型系统之String字符串类型详解
2016/06/21 Javascript
利用node.js写一个爬取知乎妹纸图的小爬虫
2017/05/03 Javascript
jquery自定义显示消息数量
2017/12/19 jQuery
对angular4子路由&辅助路由详解
2018/10/09 Javascript
详解JavaScript中的函数、对象
2019/04/01 Javascript
vue 使用高德地图vue-amap组件过程解析
2019/09/07 Javascript
vue实现购物车案例
2020/05/30 Javascript
javascript自定义加载loading效果
2020/09/15 Javascript
python使用pyhook监控键盘并实现切换歌曲的功能
2014/07/18 Python
Python编程之序列操作实例详解
2017/07/22 Python
python实现redis三种cas事务操作
2017/12/19 Python
python实现一个点绕另一个点旋转后的坐标
2019/12/04 Python
美国购车网站:TrueCar
2016/10/19 全球购物
Lou & Grey美国官网:主打舒适性面料服饰
2017/12/21 全球购物
英国第一职业高尔夫商店:Clickgolf.co.uk
2020/11/18 全球购物
幼儿园教学随笔感言
2014/02/23 职场文书
小学二年级学生评语
2014/04/21 职场文书
2014年污水处理厂工作总结
2014/12/19 职场文书
分析SQL窗口函数之排名窗口函数
2022/04/21 Oracle