PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
详解在Python程序中解析并修改XML内容的方法
Nov 16 Python
python开发环境PyScripter中文乱码问题解决方案
Sep 11 Python
python dict 字典 以及 赋值 引用的一些实例(详解)
Jan 20 Python
对numpy中数组转置的求解以及向量内积计算方法
Oct 31 Python
对pandas读取中文unicode的csv和添加行标题的方法详解
Dec 12 Python
使用Python计算玩彩票赢钱概率
Jun 26 Python
Python3+PyInstall+Sciter解决报错缺少dll、html等文件问题
Jul 15 Python
python2和python3实现在图片上加汉字的方法
Aug 22 Python
python脚本之一键移动自定格式文件方法实例
Sep 02 Python
django框架forms组件用法实例详解
Dec 10 Python
PyTorch 随机数生成占用 CPU 过高的解决方法
Jan 13 Python
PyInstaller运行原理及常用操作详解
Jun 13 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
PHP 抓取新浪读书频道的小说并生成txt电子书的代码
2009/12/18 PHP
了解Joomla 这款来自国外的php网站管理系统
2010/03/11 PHP
php中记录用户访问过的产品,在cookie记录产品id,id取得产品信息
2011/05/04 PHP
php中base64_decode与base64_encode加密解密函数实例
2014/11/24 PHP
php实现倒计时效果
2015/12/19 PHP
CI框架数据库查询之join用法分析
2016/05/18 PHP
Yii框架实现多数据库配置和操作的方法
2017/05/25 PHP
jquery日历控件实现方法分享
2014/03/07 Javascript
JQuery表单验证插件EasyValidator用法分析
2014/11/15 Javascript
jquery选择器简述
2015/08/31 Javascript
学习JavaScript设计模式之装饰者模式
2016/01/19 Javascript
基于jQuery的checkbox全选问题分析
2016/11/18 Javascript
js 原型对象和原型链理解
2017/02/09 Javascript
jquery设置css样式的多种方法(总结)
2017/02/21 Javascript
discuz表情的JS提取方法分析
2017/03/22 Javascript
详解webpack require.ensure与require AMD的区别
2017/12/13 Javascript
vue axios 给生产环境和发布环境配置不同的接口地址(推荐)
2018/05/08 Javascript
Vue使用vux-ui自定义表单验证遇到的问题及解决方法
2018/05/10 Javascript
vue 更改连接后台的api示例
2019/11/11 Javascript
JS图片懒加载的优点及实现原理
2020/01/10 Javascript
JavaScript 链表定义与使用方法示例
2020/04/28 Javascript
Python continue语句用法实例
2014/03/11 Python
浅谈python中的__init__、__new__和__call__方法
2017/07/18 Python
Python实现决策树C4.5算法的示例
2018/05/30 Python
python登录WeChat 实现自动回复实例详解
2019/05/28 Python
关于python3中setup.py小概念解析
2019/08/22 Python
python 列表推导式使用详解
2019/08/29 Python
Python如何实现大型数组运算(使用NumPy)
2020/07/24 Python
微软开源最强Python自动化神器Playwright(不用写一行代码)
2021/01/05 Python
Tirendo比利时:在线购买轮胎
2018/10/22 全球购物
Alexandre Birman美国官网:亚历山大·伯曼
2019/10/30 全球购物
毕业生个人的求职信范文
2013/12/03 职场文书
文明家庭先进事迹材
2014/01/27 职场文书
学生会副主席竞聘书
2014/03/31 职场文书
《假如》教学反思
2014/04/17 职场文书
体育运动口号
2014/06/09 职场文书