PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
windows下wxPython开发环境安装与配置方法
Jun 28 Python
用Python解析XML的几种常见方法的介绍
Apr 09 Python
python使用MySQLdb访问mysql数据库的方法
Aug 03 Python
python如何统计序列中元素
Jul 31 Python
pandas将numpy数组写入到csv的实例
Jul 04 Python
使用pandas实现csv/excel sheet互相转换的方法
Dec 10 Python
Python数据类型之String字符串实例详解
May 08 Python
python批量爬取下载抖音视频
Jun 17 Python
详解python中的time和datetime的常用方法
Jul 08 Python
Python3 使用pillow库生成随机验证码
Aug 26 Python
python使用PIL剪切和拼接图片
Mar 23 Python
python代码如何注释
Jun 01 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
强烈推荐:php.ini中文版(2)
2006/10/09 PHP
php学习 字符串课件
2008/06/15 PHP
善用事件代理,警惕闭包的性能陷阱。
2011/01/20 Javascript
JavaScript实现GriwView单列全选(自写代码)
2013/05/13 Javascript
JavaScript作用域链使用介绍
2013/08/29 Javascript
JavaScript实现的一个计算数字步数的算法分享
2014/12/06 Javascript
Javascript中使用parseInt函数需要注意的问题
2015/04/02 Javascript
JavaScript中split() 使用方法汇总
2015/04/17 Javascript
js实现跨域的方法实例详解
2015/06/24 Javascript
JavaScript资源预加载组件和滑屏组件的使用推荐
2016/03/10 Javascript
JQuery统计input和textarea文字输入数量(代码分享)
2016/12/29 Javascript
使用vue.js写一个tab选项卡效果
2017/03/25 Javascript
jQuery插件FusionCharts绘制2D柱状图和折线图的组合图效果示例【附demo源码】
2017/04/10 jQuery
Swiper 4.x 使用方法(移动端网站的内容触摸滑动)
2018/05/17 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
微信小程序登录数据解密及状态维持实例详解
2019/05/06 Javascript
vue-cli脚手架的.babelrc文件用法说明
2020/09/11 Javascript
Python、Javascript中的闭包比较
2015/02/04 Python
python3 图片referer防盗链的实现方法
2018/03/12 Python
django的登录注册系统的示例代码
2018/05/14 Python
使用selenium模拟登录解决滑块验证问题的实现
2019/05/10 Python
Python图像处理PIL各模块详细介绍(推荐)
2019/07/17 Python
python 非线性规划方式(scipy.optimize.minimize)
2020/02/11 Python
python实现从ftp服务器下载文件
2020/03/03 Python
html5定制表单_动力节点Java学院整理
2017/07/11 HTML / CSS
乐天旅游台湾网站:Rakuten Travel TW
2017/06/01 全球购物
VICHY薇姿英国官网:全球专业敏感肌护肤领先品牌
2017/07/04 全球购物
美国知名的旅游网站:OneTravel
2018/10/09 全球购物
介绍一下EJB的分类及其各自的功能及应用
2016/08/23 面试题
中间件的定义
2016/08/09 面试题
python+selenium小米商城红米K40手机自动抢购的示例代码
2021/03/24 Python
环境工程求职简历的自我评价范文
2013/10/24 职场文书
《兰亭集序》教学反思
2014/02/11 职场文书
义和团口号
2014/06/17 职场文书
Mysql 设置boolean类型的操作
2021/06/04 MySQL
《火纹风花雪月无双》预告“神秘雇佣兵” 紫发剑客
2022/04/13 其他游戏