PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python函数式编程指南(一):函数式编程概述
Jun 24 Python
Python模拟百度登录实例详解
Jan 20 Python
简单谈谈python中的Queue与多进程
Aug 25 Python
Python中常用信号signal类型实例
Jan 25 Python
Python面向对象程序设计类的封装与继承用法示例
Apr 12 Python
Python实现决策树并且使用Graphviz可视化的例子
Aug 09 Python
基于Python中的yield表达式介绍
Nov 19 Python
pandas中read_csv的缺失值处理方式
Dec 19 Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 Python
Django如何批量创建Model
Sep 01 Python
15款Python编辑器的优缺点,别再问我“选什么编辑器”啦
Oct 19 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
强烈声明: 不要使用(include/require)_once
2013/06/06 PHP
解析file_get_contents模仿浏览器头(user_agent)获取数据
2013/06/27 PHP
Zend Framework入门教程之Zend_Session会话操作详解
2016/12/08 PHP
Laravel 关联模型-关联新增和关联更新的方法
2019/10/10 PHP
关于query Javascript CSS Selector engine
2013/04/12 Javascript
Js中获取frames中的元素示例代码
2013/07/30 Javascript
javascript中RegExp保留小数点后几位数的方法分享
2013/08/13 Javascript
JS 实现导航栏悬停效果(续2)
2013/09/24 Javascript
JQuery自适应窗口大小导航菜单附源码下载
2015/09/01 Javascript
js学习总结之DOM2兼容处理重复问题的解决方法
2017/07/27 Javascript
Angular4学习笔记之实现绑定和分包
2017/08/01 Javascript
js模块加载方式浅析
2017/08/12 Javascript
Angular2 组件交互实例详解
2017/08/24 Javascript
VeeValidate在vue项目里表单校验应用案例
2018/05/09 Javascript
Vue项目中最新用到的一些实用小技巧
2018/11/06 Javascript
小程序最新获取用户昵称和头像的方法总结
2019/09/23 Javascript
[01:11:46]DOTA2-DPC中国联赛 正赛 iG vs Magma BO3 第一场 2月23日
2021/03/11 DOTA
更改Python命令行交互提示符的方法
2015/01/14 Python
Django 权限认证(根据不同的用户,设置不同的显示和访问权限)
2019/07/24 Python
更新pip3与pyttsx3文字语音转换的实现方法
2019/08/08 Python
Python List列表对象内置方法实例详解
2019/10/22 Python
浅谈django 重载str 方法
2020/05/19 Python
使用Dajngo 通过代码添加xadmin用户和权限(组)
2020/07/03 Python
python 30行代码实现蚂蚁森林自动偷能量
2021/02/08 Python
定制iPhone和Macbook保护壳:Slick Case
2018/11/21 全球购物
Noon埃及:埃及在线购物
2019/11/26 全球购物
英国门销售网站:Green Tree Doors
2020/01/07 全球购物
分解成质因数(如435234=251*17*17*3*2,据说是华为笔试题)
2014/07/16 面试题
如何用SQL语句进行模糊查找
2015/09/25 面试题
公关关系专员的自我评价分享
2013/11/20 职场文书
工作作风懒散检讨书
2014/10/29 职场文书
教师继续教育反思周记
2015/06/25 职场文书
2015小学师德工作总结
2015/07/21 职场文书
《倍数和因数》教学反思
2016/02/23 职场文书
vue+spring boot实现校验码功能
2021/05/27 Vue.js
golang生成vcf通讯录格式文件详情
2022/03/25 Golang