Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python端口扫描系统实现方法
Nov 19 Python
python高手之路python处理excel文件(方法汇总)
Jan 07 Python
Python实现定时任务
Feb 08 Python
Django 添加静态文件的两种实现方法(必看篇)
Jul 14 Python
python使用锁访问共享变量实例解析
Feb 08 Python
Python堆排序原理与实现方法详解
May 11 Python
python skimage 连通性区域检测方法
Jun 21 Python
python使用opencv驱动摄像头的方法
Aug 03 Python
python添加模块搜索路径和包的导入方法
Jan 19 Python
Python字符串通过'+'和join函数拼接新字符串的性能测试比较
Mar 05 Python
Python 3 判断2个字典相同
Aug 06 Python
Python使用docx模块实现刷题功能代码
Feb 13 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
广播爱好者需要了解的天线知识
2021/03/01 无线电
浅析PHP substr,mb_substr以及mb_strcut的区别和用法
2013/06/21 PHP
PHP保存带BOM文件的方法
2015/02/12 PHP
使用php-timeit估计php函数的执行时间
2015/09/06 PHP
Zend Framework教程之Loader以及PluginLoader用法详解
2016/03/09 PHP
Laravel路由设定和子路由设定实例分析
2016/03/30 PHP
php简单实现sql防注入的方法
2016/04/22 PHP
Smarty简单生成表单元素的方法示例
2016/05/23 PHP
PHP编译configure时常见错误的总结
2017/08/17 PHP
PHP实现通过文本文件统计页面访问量功能示例
2019/02/13 PHP
JavaScript入门教程(8) Location地址对象
2009/01/31 Javascript
浅谈js和css内联外联注意事项
2016/06/30 Javascript
js 将图片连接转换成base64格式的简单实例
2016/08/10 Javascript
Node.js中看JavaScript的引用
2017/04/22 Javascript
AngularJs实现聊天列表实时刷新功能
2017/06/15 Javascript
vue2.0中click点击当前li实现动态切换class
2017/06/21 Javascript
基于JS实现html中placeholder属性提示文字效果示例
2018/04/19 Javascript
Vue二次封装axios为插件使用详解
2018/05/21 Javascript
vue.js内置组件之keep-alive组件使用
2018/07/10 Javascript
vue 点击按钮实现动态挂载子组件的方法
2018/09/07 Javascript
在layui中使用form表单监听ajax异步验证注册的实例
2019/09/03 Javascript
Vue实现base64编码图片间的切换功能
2019/12/04 Javascript
解决vue数据不实时更新的问题(数据更改了,但数据不实时更新)
2020/10/27 Javascript
vant-ui框架的一个bug(解决切换后onload不触发)
2020/11/11 Javascript
angular *Ngif else用法详解
2020/12/15 Javascript
python保存字符串到文件的方法
2015/07/01 Python
值得收藏,Python 开发中的高级技巧
2018/11/23 Python
详解Python中is和==的区别
2019/03/21 Python
python查询MySQL将数据写入Excel
2020/10/29 Python
深入浅析HTML5中的SVG
2015/11/27 HTML / CSS
M.M.LaFleur官网:美国职业女装品牌
2020/10/27 全球购物
工程管理专业毕业生自荐信
2014/01/24 职场文书
法定代表人授权委托书
2014/04/04 职场文书
开学寄语大全
2014/04/08 职场文书
新郎新娘致辞
2015/07/31 职场文书
Redis全局ID生成器的实现
2022/06/05 Redis