Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python控制台英汉汉英电子词典
Apr 23 Python
Python的Flask框架中配置多个子域名的方法讲解
Jun 07 Python
Python模糊查询本地文件夹去除文件后缀的实例(7行代码)
Nov 09 Python
Python+Turtle动态绘制一棵树实例分享
Jan 16 Python
opencv python 基于KNN的手写体识别的实例
Aug 03 Python
Python3.5面向对象编程图文与实例详解
Apr 24 Python
python中通过selenium简单操作及元素定位知识点总结
Sep 10 Python
python类中super() 的使用解析
Dec 19 Python
tensorflow的计算图总结
Jan 12 Python
Python操作Sqlite正确实现方法解析
Feb 05 Python
python进度条显示之tqmd模块
Aug 22 Python
协程Python 中实现多任务耗资源最小的方式
Oct 19 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
PHP Session_Regenerate_ID函数双释放内存破坏漏洞
2011/01/27 PHP
php获取远程图片的两种 CURL方式和sockets方式获取远程图片
2011/11/07 PHP
PHP 观察者模式的实现代码
2013/05/10 PHP
PHP命名空间namespace的定义方法详解
2017/03/29 PHP
tp5.1 框架数据库-数据集操作实例分析
2020/05/26 PHP
jquery select动态加载选择(兼容各种浏览器)
2013/02/01 Javascript
同时使用n个window onload加载实例介绍
2013/04/25 Javascript
JS执行删除前的判断代码
2014/02/18 Javascript
Bootstrap时间选择器datetimepicker和daterangepicker使用实例解析
2016/09/17 Javascript
纯原生js实现table表格的增删
2017/01/05 Javascript
Bootstrap警告(Alerts)的实现方法
2017/03/22 Javascript
JS禁止浏览器右键查看元素或按F12审查元素自动关闭页面示例代码
2017/09/07 Javascript
基于vue-cli创建的项目的目录结构及说明介绍
2017/11/23 Javascript
微信小程序使用modal组件弹出对话框功能示例
2017/11/29 Javascript
vue 实现左右拖拽元素并且不超过他的父元素的宽度
2018/11/30 Javascript
Angular PWA使用的Demo示例
2019/01/31 Javascript
vue 组件简介
2020/07/31 Javascript
原生js+css实现tab切换功能
2020/09/17 Javascript
Openlayers3实现车辆轨迹回放功能
2020/09/29 Javascript
使用相同的Apache实例来运行Django和Media文件
2015/07/22 Python
python 递归遍历文件夹,并打印满足条件的文件路径实例
2017/08/30 Python
Python反射用法实例简析
2017/12/22 Python
Django命名URL和反向解析URL实现解析
2019/08/09 Python
Python 如何优雅的将数字转化为时间格式的方法
2019/09/26 Python
利用django创建一个简易的博客网站的示例
2020/09/29 Python
GWebs公司笔试题
2012/05/04 面试题
超市业务员岗位职责
2013/12/05 职场文书
函授毕业自我鉴定
2013/12/19 职场文书
幼儿园春游活动方案
2014/01/19 职场文书
计算机数据库专业职业生涯规划书
2014/02/08 职场文书
给老师的检讨书
2014/02/11 职场文书
《口技》教学反思
2014/02/21 职场文书
业务部门经理岗位职责
2014/02/23 职场文书
银行职员工作失误检讨书
2014/10/14 职场文书
2014年班干部工作总结
2014/11/25 职场文书
《颐和园》教学反思
2016/02/19 职场文书