Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python基于windows平台锁定键盘输入的方法
Mar 05 Python
R vs. Python 数据分析中谁与争锋?
Oct 18 Python
Python实现自动发送邮件功能
Mar 02 Python
Python实现识别手写数字大纲
Jan 29 Python
TensorFlow实现创建分类器
Feb 06 Python
python语音识别实践之百度语音API
Aug 30 Python
关于Python 的简单栅格图像边界提取方法
Jul 05 Python
Python OrderedDict的使用案例解析
Oct 25 Python
Python中__repr__和__str__区别详解
Nov 07 Python
Python实现投影法分割图像示例(一)
Jan 17 Python
pycharm如何使用anaconda中的各种包(操作步骤)
Jul 31 Python
还在手动盖楼抽奖?教你用Python实现自动评论盖楼抽奖(一)
Jun 07 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
PHP var_dump遍历对象属性的函数与应用代码
2010/06/04 PHP
微信公众号开发之语音消息识别php代码
2016/08/08 PHP
curl 出现错误的调试方法(必看)
2017/02/13 PHP
PHP弱类型语言中类型判断操作实例详解
2017/08/10 PHP
php curl批处理实现可控并发异步操作示例
2018/05/09 PHP
php中file_get_contents()函数用法实例
2019/02/21 PHP
菜鸟javascript基础资料整理2
2010/12/06 Javascript
商城常用滚动的焦点图效果代码简单实用
2013/03/28 Javascript
jquery属性选择器not has怎么写 行悬停高亮显示
2013/11/13 Javascript
结合JQ1.9通过js正则判断各种浏览器版本的方法
2013/12/30 Javascript
flash遮住div问题的正确解决方法
2014/02/27 Javascript
js图片轮播手动切换特效
2017/01/12 Javascript
详解微信小程序开发之formId使用(模板消息)
2019/08/27 Javascript
[15:56]Heroes18_暗影萨满(完美)
2014/10/31 DOTA
python传递参数方式小结
2015/04/17 Python
在Python中编写数据库模块的教程
2015/04/29 Python
Python捕捉和模拟鼠标事件的方法
2015/06/03 Python
python 接口返回的json字符串实例
2018/03/27 Python
Python hashlib模块用法实例分析
2018/06/12 Python
python利用tkinter实现屏保
2019/07/30 Python
Python安装tar.gz格式文件方法详解
2020/01/19 Python
Python类如何定义私有变量
2020/02/03 Python
Python matplotlib 绘制双Y轴曲线图的示例代码
2020/06/12 Python
如何利用python进行时间序列分析
2020/08/04 Python
英国领先的维生素和营养补充剂直接供应商:Healthspan
2019/04/22 全球购物
德国最新街头服饰网上商店:BODYCHECK
2019/09/15 全球购物
试用期自我鉴定范文
2014/03/20 职场文书
房屋出售协议书
2014/04/10 职场文书
作风整顿个人剖析材料
2014/10/06 职场文书
群众路线自我剖析范文
2014/11/04 职场文书
2014年机关党建工作总结
2014/11/11 职场文书
Nginx域名转发https访问的实现
2021/03/31 Servers
go设置多个GOPATH的方式
2021/05/05 Golang
python分分钟绘制精美地图海报
2022/02/15 Python
全新239军机修复记
2022/04/05 无线电
分析SQL窗口函数之排名窗口函数
2022/04/21 Oracle