Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python检测字符串中是否包含某字符集合中的字符
May 21 Python
使用Python操作MySQL的一些基本方法
Aug 16 Python
Python中关键字nonlocal和global的声明与解析
Mar 12 Python
Python3 实现随机生成一组不重复数并按行写入文件
Apr 09 Python
Python实现聊天机器人的示例代码
Jul 09 Python
python学生信息管理系统(完整版)
Apr 05 Python
Python求两个圆的交点坐标或三个圆的交点坐标方法
Nov 07 Python
对python多线程中Lock()与RLock()锁详解
Jan 11 Python
python多线程高级锁condition简单用法示例
Nov 07 Python
pytorch ImageFolder的覆写实例
Feb 20 Python
Python通过zookeeper实现分布式服务代码解析
Jul 22 Python
python sleep和wait对比总结
Feb 03 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
模仿OSO的论坛(二)
2006/10/09 PHP
使用Apache的htaccess防止图片被盗链的解决方法
2013/04/27 PHP
php cookie中点号(句号)自动转为下划线问题
2014/10/21 PHP
PHP从尾到头打印链表实例讲解
2018/09/27 PHP
PHP实现微信退款的方法示例
2019/03/26 PHP
AngularJS 简单应用实例
2016/07/28 Javascript
微信小程序 wx:key详细介绍
2016/10/28 Javascript
详解Angular.js指令中scope类型的几种特殊情况
2017/02/21 Javascript
使用nvm管理不同版本的node与npm的方法
2017/10/31 Javascript
Vue2 SSR渲染根据不同页面修改 meta
2017/11/20 Javascript
Koa2 之文件上传下载的示例代码
2018/03/29 Javascript
vue 的点击事件获取当前点击的元素方法
2018/09/15 Javascript
详解js动态获取浏览器或页面等容器的宽高
2019/03/13 Javascript
jquery多级树形下拉菜单的实例代码
2019/07/09 jQuery
原生js添加一个或多个类名的方法分析
2019/07/30 Javascript
Element Tooltip 文字提示的使用示例
2020/07/26 Javascript
Vue 实现监听窗口关闭事件,并在窗口关闭前发送请求
2020/09/01 Javascript
Python使用正则表达式抓取网页图片的方法示例
2017/04/21 Python
python 定时修改数据库的示例代码
2018/04/08 Python
python实现机器学习之多元线性回归
2018/09/06 Python
解决python opencv无法显示图片的问题
2018/10/28 Python
Python实现字符型图片验证码识别完整过程详解
2019/05/10 Python
Python中asyncio模块的深入讲解
2019/06/10 Python
python如何以表格形式打印输出的方法示例
2019/06/21 Python
Python批量修改图片分辨率的实例代码
2019/07/04 Python
python圣诞树编写实例详解
2020/02/13 Python
美国最受欢迎的童装品牌之一:The Children’s Place
2016/07/23 全球购物
供货协议书
2014/04/22 职场文书
金融系应届毕业生求职信
2014/05/26 职场文书
先进党组织事迹材料
2014/12/26 职场文书
催款函范本大全
2015/06/24 职场文书
庆七一活动简报
2015/07/20 职场文书
新员工入职感言范文!
2019/07/04 职场文书
网络安全倡议书(3篇)
2019/09/18 职场文书
基于angular实现树形二级表格
2021/10/16 Javascript
部分武汉产收音机展览
2022/04/07 无线电