PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中__init__和__new__的区别详解
Jul 09 Python
Python随机生成带特殊字符的密码
Mar 02 Python
TensorFlow利用saver保存和提取参数的实例
Jul 26 Python
python web框架中实现原生分页
Sep 08 Python
TensorFlow索引与切片的实现方法
Nov 20 Python
python3 常见解密加密算法实例分析【base64、MD5等】
Dec 19 Python
Python GUI库PyQt5图形和特效样式QSS介绍
Feb 25 Python
pycharm中导入模块错误时提示Try to run this command from the system terminal
Mar 26 Python
基于python调用jenkins-cli实现快速发布
Aug 14 Python
Python下使用Trackbar实现绘图板
Oct 27 Python
python Polars库的使用简介
Apr 21 Python
Python提取PDF指定内容并生成新文件
Jun 09 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
在apache下限制每个虚拟主机的并发数!!!!
2006/10/09 PHP
php 随机记录mysql rand()造成CPU 100%的解决办法
2010/05/18 PHP
解析PHP可变函数的经典用法
2013/06/20 PHP
php中sql注入漏洞示例 sql注入漏洞修复
2014/01/24 PHP
PHP基于MySQL数据库实现对象持久层的方法
2015/06/17 PHP
php ajax实现文件上传进度条
2016/03/29 PHP
Javascript-Mozilla和IE中的一个函数直接量的问题分析
2007/08/12 Javascript
jquery 操作日期、星期、元素的追加的实现代码
2012/02/07 Javascript
javascript获取所有同类checkbox选项(实例代码)
2013/11/07 Javascript
jQuery select表单提交省市区城市三级联动核心代码
2014/06/09 Javascript
JS实现网页滚动条感应鼠标变色的方法
2015/02/26 Javascript
JavaScript之数组(Array)详解
2015/04/01 Javascript
浅谈JavaScript中的apply/call/bind和this的使用
2017/02/26 Javascript
PHP实现本地图片上传和验证功能
2017/02/27 Javascript
vue数字类型过滤器的示例代码
2017/09/07 Javascript
JavaScript重复元素处理方法分析【统计个数、计算、去重复等】
2017/12/14 Javascript
js数组方法reduce经典用法代码分享
2018/01/07 Javascript
微信JSSDK实现打开摄像头拍照再将相片保存到服务器
2019/11/15 Javascript
nuxt 每个页面head标签内容设置方式
2020/11/05 Javascript
[01:41]DOTA2超级联赛专访YYF 称一辈子难忘TI2
2013/05/28 DOTA
python关闭windows进程的方法
2015/04/18 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
2019/08/20 Python
python sklearn包——混淆矩阵、分类报告等自动生成方式
2020/02/28 Python
python代数式括号有效性检验示例代码
2020/10/04 Python
canvas实现有递增动画的环形进度条的实现方法
2019/07/10 HTML / CSS
10条PHP编程习惯
2014/05/26 面试题
动物学专业毕业生求职信
2013/10/11 职场文书
工作个人的自我评价
2014/01/14 职场文书
新闻编辑自荐书范文
2014/02/12 职场文书
运动会入场词60字
2014/02/15 职场文书
内蒙古鄂尔多斯市市长寄语
2014/04/10 职场文书
典型事迹材料范文
2014/12/29 职场文书
Go使用协程交替打印字符
2021/04/29 Golang
python字符串常规操作大全
2021/05/02 Python
Python re.sub 反向引用的实现
2021/07/07 Python
box-shadow单边阴影的实现
2023/05/21 HTML / CSS