Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现sublime3的less编译插件示例
Apr 27 Python
python实现的登录和操作开心网脚本分享
Jul 09 Python
Python中的高级数据结构详解
Mar 27 Python
python 匹配url中是否存在IP地址的方法
Jun 04 Python
Numpy之文件存取的示例代码
Aug 03 Python
python ---lambda匿名函数介绍
Mar 13 Python
numpy.where() 用法详解
May 27 Python
Django多数据库的实现过程详解
Aug 01 Python
Python检查 云备份进程是否正常运行代码实例
Aug 22 Python
Python编写一个验证码图片数据标注GUI程序附源码
Dec 09 Python
python内置模块collections知识点总结
Dec 19 Python
Django DRF路由与扩展功能的实现
Jun 03 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
PHP json格式和js json格式 js跨域调用实现代码
2012/09/08 PHP
php实现HTML实体编号与非ASCII字符串相互转换类实例
2016/11/02 PHP
js apply/call/caller/callee/bind使用方法与区别分析
2009/10/28 Javascript
javascript getElementsByClassName实现代码
2010/10/11 Javascript
通过JavaScript使Div居中并随网页大小改变而改变
2013/06/24 Javascript
JSON取值前判断
2014/12/23 Javascript
jQuery中:has选择器用法实例
2014/12/30 Javascript
readonly和disabled属性的区别
2015/07/26 Javascript
JS+CSS实现下拉列表框美化效果(3款)
2015/08/15 Javascript
如何解决谷歌浏览器下jquery无法获取图片的尺寸
2015/09/10 Javascript
详细总结Javascript中的焦点管理
2016/09/17 Javascript
angular forEach方法遍历源码解读
2017/01/25 Javascript
解析NodeJS异步I/O的实现
2017/04/13 NodeJs
jQuery中的deferred对象和extend方法详解
2017/05/08 jQuery
JavaScript取得gridview中获取checkbox选中的值
2017/07/24 Javascript
支付宝小程序tabbar底部导航
2018/11/06 Javascript
nodejs实现聊天机器人功能
2019/09/19 NodeJs
JavaScript/TypeScript 实现并发请求控制的示例代码
2021/01/18 Javascript
python安装与使用redis的方法
2016/04/19 Python
python线程、进程和协程详解
2016/07/19 Python
python统计文章中单词出现次数实例
2020/02/27 Python
美国Randolph太阳镜官网:美国制造的飞行员太阳镜和射击眼镜
2018/06/15 全球购物
Priority Pass机场贵宾室会籍计划:全球超过1200间机场贵宾室
2018/08/26 全球购物
LivingSocial英国:英国本地优惠
2019/02/22 全球购物
有影响力的品牌之家:Our Social Collective
2019/06/08 全球购物
Chinti & Parker官网:奢华羊绒女装和创新针织设计
2021/01/01 全球购物
商场中秋节广播稿
2014/01/17 职场文书
岗位说明书范文
2014/05/07 职场文书
安全生产标语
2014/06/06 职场文书
党员教师个人对照检查材料范文
2014/09/25 职场文书
离婚财产分配协议书
2014/10/21 职场文书
党的群众路线教育实践活动先进个人材料
2014/12/24 职场文书
联村联户简报
2015/07/21 职场文书
python3 hdf5文件 遍历代码
2021/05/19 Python
基于angular实现树形二级表格
2021/10/16 Javascript
SQL Server删除表中的重复数据
2022/05/25 SQL Server