PyTorch加载自己的数据集实例详解


Posted in Python onMarch 18, 2020

数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset
 def __init__(self, path_dir, transform=None): #初始化一些属性
  self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
  self.transform = transform #对图形进行处理,如标准化、截取、转换等
  self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中
 
 def __len__(self):#返回整个数据集的大小
  return len(self.images)
 
 def __getitem__(self,index):#根据索引index返回图像及标签
  image_index = self.images[index]#根据索引获取图像文件名称
  img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
  img = Image.open(img_path).convert('RGB')# 读取图像
    
  # 根据目录名称获取图像标签(cat或dog)
  label = img_path.split('\\')[-1].split('.')[0]
  #把字符转换为数字cat-0,dog-1
  label = 1 if 'dog' in label else 0
  
  if self.transform is not None:
   img = self.transform(img)
  return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>

1.2.5 查看图像形状

i=1
for img, label in dataset:
    if i
img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T
transform = T.Compose([
 T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
 T.CenterCrop(224), # 从图片中间切出224*224的图片
 T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset: 
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
 print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
  transforms.RandomSizedCrop(224),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
 ])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用Python的Twisted框架编写简单的网络客户端
Apr 16 Python
python脚本实现数据导出excel格式的简单方法(推荐)
Dec 30 Python
Python调用C语言的方法【基于ctypes模块】
Jan 22 Python
对Python中Iterator和Iterable的区别详解
Oct 18 Python
Pycharm 实现下一个文件引用另外一个文件的方法
Jan 17 Python
python 获取微信好友列表的方法(微信web)
Feb 21 Python
python3利用Socket实现通信的方法示例
May 06 Python
python开发实例之Python的Twisted框架中Deferred对象的详细用法与实例
Mar 19 Python
Python 如何创建一个线程池
Jul 28 Python
python tkinter的消息框模块(messagebox,simpledialog)
Nov 07 Python
python中yield的用法详解
Jan 13 Python
OpenCV全景图像拼接的实现示例
Jun 05 Python
Python进程间通信multiprocess代码实例
Mar 18 #Python
python实现超级玛丽游戏
Mar 18 #Python
python实现超级马里奥
Mar 18 #Python
Python开发企业微信机器人每天定时发消息实例
Mar 17 #Python
10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)
Mar 17 #Python
Python Selenium安装及环境配置的实现
Mar 17 #Python
详解python环境安装selenium和手动下载安装selenium的方法
Mar 17 #Python
You might like
SONY ICF-SW55的电路分析
2021/03/02 无线电
提高PHP性能的编码技巧以及性能优化详细解析
2013/08/24 PHP
PHP 验证登陆类分享
2015/03/13 PHP
学习php设计模式 php实现命令模式(command)
2015/12/08 PHP
Laravel5.1 框架路由基础详解
2020/01/04 PHP
php实现通过stomp协议连接ActiveMQ操作示例
2020/02/23 PHP
js控制的回到页面顶端goTop的代码实现
2013/03/20 Javascript
利用JS解决ie6不支持max-width,max-height问题的方法
2014/01/02 Javascript
jquery实现页面关键词高亮显示的方法
2015/03/12 Javascript
JS实现双击屏幕滚动效果代码
2015/10/28 Javascript
jQuery根据name属性进行查找的用法分析
2016/06/23 Javascript
AngularJS通过$location获取及改变当前页面的URL
2016/09/23 Javascript
Yarn的安装与使用详细介绍
2016/10/25 Javascript
jQuery时间日期三级联动(推荐)
2016/11/27 Javascript
如何进行微信公众号开发的本地调试的方法
2019/06/16 Javascript
jquery实现垂直无限轮播的方法分析
2019/07/16 jQuery
vue子路由跳转实现tab选项卡
2019/07/24 Javascript
微信小程序获取位置展示地图并标注信息的实例代码
2019/09/01 Javascript
vue动态合并单元格并添加小计合计功能示例
2020/11/26 Vue.js
Python实现将数据库一键导出为Excel表格的实例
2016/12/30 Python
python编写微信远程控制电脑的程序
2018/01/05 Python
Python中数组,列表:冒号的灵活用法介绍(np数组,列表倒序)
2018/04/18 Python
Python socket实现的简单通信功能示例
2018/08/21 Python
用pycharm开发django项目示例代码
2018/10/24 Python
Python列表(List)知识点总结
2019/02/18 Python
Django实现文件上传下载
2019/10/06 Python
Python使用Opencv实现图像特征检测与匹配的方法
2019/10/30 Python
芬兰灯具网上商店:Nettilamppu.fi
2018/06/30 全球购物
Java平台和其他软件平台有什么不同
2015/06/05 面试题
预备党员转正思想汇报
2014/01/12 职场文书
工程招投标邀请书
2014/01/26 职场文书
有关九一八事变的演讲稿
2014/09/14 职场文书
2015年销售员工作总结范文
2015/04/07 职场文书
MySQL中日期型单行函数代码详解
2021/06/21 MySQL
「玫瑰之王的葬礼」舞台剧主视觉图公开
2022/03/21 日漫
「约定的梦幻岛」作画发布诺曼生日新绘
2022/03/21 日漫