使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
Python 搭建Web站点之Web服务器与Web框架
Nov 06 Python
Python使用中文正则表达式匹配指定中文字符串的方法示例
Jan 20 Python
Python2.7编程中SQLite3基本操作方法示例
Aug 09 Python
python 实现视频流下载保存MP4的方法
Jan 09 Python
python打印9宫格、25宫格等奇数格 满足横竖斜相加和相等
Jul 19 Python
解决django 向mysql中写入中文字符出错的问题
May 18 Python
python实现爱奇艺登陆密码RSA加密的方法示例详解
May 27 Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 Python
Python爬虫与反爬虫大战
Jul 30 Python
python字典与json转换的方法总结
Dec 28 Python
Python之qq自动发消息的示例代码
Feb 18 Python
如何用六步教会你使用python爬虫爬取数据
Apr 06 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
利用PHP创建动态图像
2006/10/09 PHP
Zend Guard一些常见问题解答
2008/09/11 PHP
解析php防止form重复提交的方法
2013/07/01 PHP
php socket实现的聊天室代码分享
2014/08/16 PHP
实例讲解PHP面向对象之多态
2014/08/20 PHP
SSO单点登录的PHP实现方法(Laravel框架)
2016/03/23 PHP
PHP实现的字符串匹配算法示例【sunday算法】
2017/12/19 PHP
PHP simplexml_import_dom()函数讲解
2019/02/03 PHP
javascript下判断一个元素是否存在的代码
2010/03/05 Javascript
EasyUi tabs的高度与宽度根据IE窗口的变化自适应代码
2010/10/26 Javascript
javascript中expression的用法整理
2014/05/13 Javascript
利用Vue.js指令实现全选功能
2016/09/08 Javascript
Bootstrap学习笔记之环境配置(1)
2016/12/07 Javascript
基于Vue实现tab栏切换内容不断实时刷新数据功能
2017/04/13 Javascript
Angular2使用jQuery的方法教程
2017/05/28 jQuery
网页中的图片查看器viewjs使用方法
2017/07/11 Javascript
利用JavaScript如何查询某个值是否数组内
2017/07/30 Javascript
基于jQuery实现的单行公告活动轮播效果
2017/08/23 jQuery
Node.js引入UIBootstrap的方法示例
2018/05/11 Javascript
独立部署小程序基于nodejs的服务器过程详解
2019/06/24 NodeJs
python实现获取Ip归属地等信息
2016/08/27 Python
pycharm中import呈现灰色原因的解决方法
2020/03/04 Python
Django如何使用redis作为缓存
2020/05/21 Python
通俗讲解python 装饰器
2020/09/07 Python
香港草莓网土耳其网站:Strawberrynet TR
2017/03/02 全球购物
阿联酋网上花店:Ferns N Petals
2018/02/14 全球购物
Farfetch巴西官网:奢侈品牌时尚购物平台
2020/10/19 全球购物
秘书岗位职责
2013/11/18 职场文书
逃课上网检讨书
2014/02/20 职场文书
语文课外活动总结
2014/08/27 职场文书
单位工作证明书格式
2014/10/04 职场文书
2015年电工工作总结
2015/04/10 职场文书
爱心捐款活动总结
2015/05/09 职场文书
学雷锋广播稿大全
2015/08/19 职场文书
导游词之桂林
2019/08/20 职场文书
python 使用tkinter与messagebox写界面和弹窗
2022/03/20 Python