Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python编程开发之textwrap文本样式处理技巧
Nov 13 Python
Python使用SQLite和Excel操作进行数据分析
Jan 20 Python
Python下调用Linux的Shell命令的方法
Jun 12 Python
Python中循环后使用list.append()数据被覆盖问题的解决
Jul 01 Python
Python基于OpenCV库Adaboost实现人脸识别功能详解
Aug 25 Python
Python lxml解析HTML并用xpath获取元素的方法
Jan 02 Python
python中下标和切片的使用方法解析
Aug 27 Python
Windows平台Python编程必会模块之pywin32介绍
Oct 01 Python
python 在threading中如何处理主进程和子线程的关系
Apr 25 Python
Python语言编写智力问答小游戏功能
Oct 13 Python
Python matplotlib 利用随机函数生成变化图形
Apr 26 Python
python+pyhyper实现识别图片中的车牌号思路详解
Dec 24 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
PHP 一个页面执行时间类代码
2010/03/05 PHP
析构函数与php的垃圾回收机制详解
2013/10/28 PHP
linux下实现定时执行php脚本
2015/02/13 PHP
php简单判断文本编码的方法
2015/07/30 PHP
php获取汉字拼音首字母的方法
2015/10/21 PHP
解决ThinkPHP下使用上传插件Uploadify浏览器firefox报302错误的方法
2015/12/18 PHP
php pdo操作数据库示例
2017/03/10 PHP
JavaScript 选中文字并响应获取的实现代码
2011/08/28 Javascript
node.js中的querystring.escape方法使用说明
2014/12/10 Javascript
jQuery弹出框代码封装DialogHelper
2015/01/30 Javascript
jQuery实现单击和鼠标感应事件
2015/02/01 Javascript
jQuery实现延迟跳转的方法
2015/06/05 Javascript
angular route中使用resolve在uglify压缩后问题解决
2016/09/21 Javascript
微信小程序实战之自定义模态弹窗(8)
2017/04/18 Javascript
基于BootStrap multiselect.js实现的下拉框联动效果
2017/07/28 Javascript
解决JQuery全选/反选第二次失效的问题
2017/10/11 jQuery
webpack 样式加载的实现原理
2018/06/12 Javascript
取消Bootstrap的dropdown-menu点击默认关闭事件方法
2018/08/10 Javascript
在vue.js中使用JSZip实现在前端解压文件的方法
2018/09/05 Javascript
vue favicon设置以及动态修改favicon的方法
2018/12/21 Javascript
js面向对象封装级联下拉菜单列表的实现步骤
2021/02/08 Javascript
python根据给定文件返回文件名和扩展名的方法
2015/03/27 Python
由Python运算π的值深入Python中科学计算的实现
2015/04/17 Python
Python读写文件方法总结
2015/06/09 Python
Python实现的删除重复文件或图片功能示例【去重】
2019/04/23 Python
Tensorflow读取并输出已保存模型的权重数值方式
2020/01/04 Python
Pyecharts 动态地图 geo()和map()的安装与用法详解
2020/03/25 Python
Python更换pip源方法过程解析
2020/05/19 Python
python如何遍历指定路径下所有文件(按按照时间区间检索)
2020/09/14 Python
自主招生自荐信格式
2013/12/03 职场文书
培训协议书范本
2014/04/22 职场文书
团代会宣传工作方案
2014/05/08 职场文书
2014副镇长民主生活会个人对照检查材料思想汇报
2014/09/30 职场文书
2014年小学教学工作总结
2014/11/13 职场文书
企业员工辞职信范文
2015/05/12 职场文书
《兰兰过桥》教学反思
2016/02/20 职场文书