pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(五):socket的一些补充
Jun 09 Python
python通过floor函数舍弃小数位的方法
Mar 17 Python
Python实现简单的文件传输与MySQL备份的脚本分享
Jan 03 Python
PHP网页抓取之抓取百度贴吧邮箱数据代码分享
Apr 13 Python
python实现读取excel写入mysql的小工具详解
Nov 20 Python
基于windows下pip安装python模块时报错总结
Jun 12 Python
Python Web程序搭建简单的Web服务器
Jul 31 Python
python基于property()函数定义属性
Jan 22 Python
Python中if有多个条件处理方法
Feb 26 Python
Django Path转换器自定义及正则代码实例
May 29 Python
python中remove函数的踩坑记录
Jan 04 Python
python tkinter模块的简单使用
Apr 07 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
星际争霸教主Flash的ID由来:你永远不会知道他之前的ID是www!
2019/01/18 星际争霸
杏林同学录(八)
2006/10/09 PHP
PHP 观察者模式的实现代码
2013/05/10 PHP
PHP实现生成唯一编号(36进制的不重复编号)
2014/07/01 PHP
Linux下PHP安装mcrypt扩展模块笔记
2014/09/10 PHP
php判断是否为ajax请求的方法
2016/11/29 PHP
利用PHP实现一个简单的用户登记表示例
2017/04/25 PHP
Laravel5.5 数据库迁移:创建表与修改表示例
2019/10/23 PHP
javascript之dhDataGrid Ver2.0.0代码
2007/07/01 Javascript
javascript 对表格的行和列都能加亮显示
2008/12/26 Javascript
jquery 实现二级/三级/多级联动菜单的思路及代码
2013/04/08 Javascript
javascript模拟map输出与去除重复项的方法
2015/02/09 Javascript
每天一篇javascript学习小结(面向对象编程)
2015/11/20 Javascript
input点击后placeholder中的提示消息消失
2016/01/15 Javascript
JavaScript中的Array 对象(数组对象)
2016/06/02 Javascript
JavaScript实现的微信二维码图片生成器的示例
2016/10/26 Javascript
仿iframe效果Aajx文件上传实例
2016/11/18 Javascript
bootstrap下拉菜单使用方法解析
2017/01/13 Javascript
jQuery插件HighCharts实现的2D堆条状图效果示例【附demo源码下载】
2017/03/14 Javascript
JavaScript对象_动力节点Java学院整理
2017/06/23 Javascript
Python功能键的读取方法
2015/05/28 Python
Python中for循环和while循环的基本使用方法
2015/08/21 Python
快速了解Python相对导入
2018/01/12 Python
python3.4.3下逐行读入txt文本并去重的方法
2018/04/29 Python
Python常用爬虫代码总结方便查询
2019/02/25 Python
基于Python实现用户管理系统
2019/02/26 Python
Python importlib动态导入模块实现代码
2020/04/16 Python
Python SQLAlchemy库的使用方法
2020/10/13 Python
使用css3实现超炫的loading加载动画效果
2014/05/07 HTML / CSS
汽车专业学生自我评价
2014/01/19 职场文书
详细的大学生创业计划书模板
2014/01/27 职场文书
学生会离职感言
2014/02/11 职场文书
学员自我鉴定
2014/03/19 职场文书
运动会400米加油稿(8篇)
2014/09/22 职场文书
Go语言基础知识点介绍
2021/07/04 Golang
海康机器人重磅发布全新算法开发平台VM4.2
2022/04/21 数码科技