PyTorch加载自己的数据集实例详解


Posted in Python onMarch 18, 2020

数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset
 def __init__(self, path_dir, transform=None): #初始化一些属性
  self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
  self.transform = transform #对图形进行处理,如标准化、截取、转换等
  self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中
 
 def __len__(self):#返回整个数据集的大小
  return len(self.images)
 
 def __getitem__(self,index):#根据索引index返回图像及标签
  image_index = self.images[index]#根据索引获取图像文件名称
  img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
  img = Image.open(img_path).convert('RGB')# 读取图像
    
  # 根据目录名称获取图像标签(cat或dog)
  label = img_path.split('\\')[-1].split('.')[0]
  #把字符转换为数字cat-0,dog-1
  label = 1 if 'dog' in label else 0
  
  if self.transform is not None:
   img = self.transform(img)
  return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>

1.2.5 查看图像形状

i=1
for img, label in dataset:
    if i
img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T
transform = T.Compose([
 T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
 T.CenterCrop(224), # 从图片中间切出224*224的图片
 T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset: 
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
 print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
  transforms.RandomSizedCrop(224),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
 ])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
rhythmbox中文名乱码问题解决方法
Sep 06 Python
python爬虫_实现校园网自动重连脚本的教程
Apr 22 Python
python3中zip()函数使用详解
Jun 29 Python
Python+pyplot绘制带文本标注的柱状图方法
Jul 08 Python
详解python中的生成器、迭代器、闭包、装饰器
Aug 22 Python
Python3并发写文件与Python对比
Nov 20 Python
Python类反射机制使用实例解析
Dec 30 Python
Python多线程实现支付模拟请求过程解析
Apr 21 Python
完美解决python针对hdfs上传和下载的问题
Jun 05 Python
Python pip安装第三方库实现过程解析
Jul 09 Python
python使用隐式循环快速求和的实现示例
Sep 11 Python
python Tornado框架的使用示例
Oct 19 Python
Python进程间通信multiprocess代码实例
Mar 18 #Python
python实现超级玛丽游戏
Mar 18 #Python
python实现超级马里奥
Mar 18 #Python
Python开发企业微信机器人每天定时发消息实例
Mar 17 #Python
10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)
Mar 17 #Python
Python Selenium安装及环境配置的实现
Mar 17 #Python
详解python环境安装selenium和手动下载安装selenium的方法
Mar 17 #Python
You might like
PHP-CGI进程CPU 100% 与 file_get_contents 函数的关系分析
2011/08/15 PHP
PHP 设计模式之观察者模式介绍
2012/02/22 PHP
PHP之短标签开启设置
2013/06/17 PHP
如何在Ubuntu下启动Apache的Rewrite功能
2013/07/05 PHP
完美利用Yii2微信后台开发的系列总结
2016/07/18 PHP
不要小看注释掉的JS 引起的安全问题
2008/12/27 Javascript
关于JS管理作用域的问题
2013/04/10 Javascript
jquery.form.js实现将form提交转为ajax方式提交的方法
2015/04/07 Javascript
Jquery和angularjs获取check框选中的值的方法汇总
2016/01/17 Javascript
Windows环境下npm install 报错: operation not permitted, rename的解决方法
2016/09/26 Javascript
解决vue中修改export default中脚本报一大堆错的问题
2018/08/27 Javascript
JavaScript中AOP的实现与应用
2019/05/06 Javascript
jquery 插件重新绑定的处理方法分析
2019/11/23 jQuery
ES6 async、await的基本使用方法示例
2020/06/06 Javascript
[03:59]DOTA2英雄梦之声_第07期_水晶室女
2014/06/23 DOTA
使用Python的内建模块collections的教程
2015/04/28 Python
Python随机数用法实例详解【基于random模块】
2017/04/18 Python
Python基础语言学习笔记总结(精华)
2017/11/14 Python
每天迁移MySQL历史数据到历史库Python脚本
2018/04/13 Python
python 对dataframe下面的值进行大规模赋值方法
2018/06/09 Python
python自带tkinter库实现棋盘覆盖图形界面
2019/07/17 Python
Python 最强编辑器详细使用指南(PyCharm )
2019/09/16 Python
python函数中将变量名转换成字符串实例
2020/05/11 Python
python脚本定时发送邮件
2020/12/22 Python
CSS3实现可关闭的下拉手风琴菜单效果
2015/08/31 HTML / CSS
html5时钟实现代码
2010/10/22 HTML / CSS
受希腊女神灵感的晚礼服、鸡尾酒礼服和婚纱:THEIA
2018/04/15 全球购物
英国名牌服装购物网站:OD’s Designer
2019/09/02 全球购物
职称自我鉴定
2013/10/15 职场文书
计划生育标语
2014/06/23 职场文书
依法行政工作汇报
2014/10/28 职场文书
2014年英语教师工作总结
2014/12/03 职场文书
毕业生入职感言
2015/07/31 职场文书
导游词之江西赣州
2019/10/15 职场文书
python迷宫问题深度优先遍历实例
2021/06/20 Python
Vue elementUI表单嵌套表格并对每行进行校验详解
2022/02/18 Vue.js