Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python函数学习笔记
Oct 07 Python
Python利用Nagios增加微信报警通知的功能
Feb 18 Python
Python爬虫利用cookie实现模拟登陆实例详解
Jan 12 Python
Python数据结构之单链表详解
Sep 12 Python
python requests更换代理适用于IP频率限制的方法
Aug 21 Python
pycharm 设置项目的根目录教程
Feb 12 Python
python通用读取vcf文件的类(复制粘贴即可用)
Feb 29 Python
python pyqtgraph 保存图片到本地的实例
Mar 14 Python
Python 实现将某一列设置为str类型
Jul 14 Python
python 用opencv实现霍夫线变换
Nov 27 Python
使用Python实现音频双通道分离
Dec 25 Python
使用python实现学生信息管理系统
Feb 25 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
smarty自定义函数用法示例
2016/05/20 PHP
php实现压缩合并js的方法【附demo源码下载】
2016/09/22 PHP
PHP中Cookie的使用详解(简单易懂)
2017/04/28 PHP
PHP Swoole异步Redis客户端实现方法示例
2019/10/24 PHP
javascript之Partial Application学习
2013/01/10 Javascript
javascript实现yield的方法
2013/11/06 Javascript
Google (Local) Search API的简单使用介绍
2013/11/28 Javascript
jQuery自带的一些常用方法总结
2014/09/03 Javascript
每天一篇javascript学习小结(Date对象)
2015/11/13 Javascript
关于JS变量和作用域详解
2016/07/28 Javascript
Javascript 闭包详解及实例代码
2016/11/30 Javascript
微信小程序 获取二维码实例详解
2017/06/23 Javascript
详解用node.js实现简单的反向代理
2017/06/26 Javascript
JS声明对象时属性名加引号与不加引号的问题及解决方法
2018/02/16 Javascript
原生JS+HTML5实现的可调节写字板功能示例
2018/08/30 Javascript
微信小程序地图(map)组件点击(tap)获取经纬度的方法
2019/01/10 Javascript
简单了解Vue computed属性及watch区别
2020/07/10 Javascript
vue中选中多个选项并且改变选中的样式的实例代码
2020/09/16 Javascript
python基于mysql实现的简单队列以及跨进程锁实例详解
2014/07/07 Python
python分割列表(list)的方法示例
2017/05/07 Python
利用Python实现Windows下的鼠标键盘模拟的实例代码
2017/07/13 Python
对python实时得到鼠标位置的示例讲解
2018/10/14 Python
python列表推导式操作解析
2019/11/26 Python
Django自带的用户验证系统实现
2020/12/18 Python
农贸市场管理制度
2014/01/31 职场文书
关于母亲节的感言
2014/02/04 职场文书
大学生全国两会报告感想
2014/03/17 职场文书
招股说明书范本
2014/05/06 职场文书
校园活动策划方案
2014/06/13 职场文书
新闻学专业求职信
2014/07/28 职场文书
教师群众路线学习心得体会
2014/11/04 职场文书
2014年工商所工作总结
2014/12/09 职场文书
小学班主任事迹材料
2014/12/17 职场文书
结婚通知短信怎么写
2015/04/17 职场文书
mysql下的max_allowed_packet参数设置详解
2022/02/12 MySQL
Python实现Hash算法
2022/03/18 Python