pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
多线程爬虫批量下载pcgame图片url 保存为xml的实现代码
Jan 17 Python
Python3基础之函数用法
Aug 13 Python
Python如何实现文本转语音
Aug 08 Python
分析python切片原理和方法
Dec 19 Python
浅谈Django自定义模板标签template_tags的用处
Dec 20 Python
pygame游戏之旅 添加icon和bgm音效的方法
Nov 21 Python
python对csv文件追加写入列的方法
Aug 01 Python
django ManyToManyField多对多关系的实例详解
Aug 09 Python
python GUI库图形界面开发之PyQt5状态栏控件QStatusBar详细使用方法实例
Feb 28 Python
详解向scrapy中的spider传递参数的几种方法(2种)
Sep 28 Python
Python绘制K线图之可视化神器pyecharts的使用
Mar 02 Python
Python实现简单得递归下降Parser
May 02 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
让你的WINDOWS同时支持MYSQL4,MYSQL4.1,MYSQL5X
2006/12/06 PHP
php连接与操作PostgreSQL数据库的方法
2014/12/25 PHP
PHP文件上传问题汇总(文件大小检测、大文件上传处理)
2015/12/24 PHP
jQuery使用ajaxSubmit()提交表单示例
2014/04/04 Javascript
javascript实现textarea中tab键的缩排处理方法
2015/06/26 Javascript
JavaScript实现动态删除列表框值的方法
2015/08/12 Javascript
jQuery基于ajax()使用serialize()提交form数据的方法
2015/12/08 Javascript
jquery实现界面无刷新加载登陆注册
2016/07/30 Javascript
jQuery实现将div中滚动条滚动到指定位置的方法
2016/08/10 Javascript
JS提示:Uncaught SyntaxError:Unexpected token ) 错误的解决方法
2016/08/19 Javascript
jQuery的事件预绑定
2016/12/05 Javascript
node.js中实现kindEditor图片上传功能的方法教程
2017/04/26 Javascript
将 vue 生成的 js 上传到七牛的实例
2017/07/28 Javascript
利用vue+elementUI实现部分引入组件的方法详解
2017/11/22 Javascript
基于Vuejs的搜索匹配功能实现方法
2018/03/03 Javascript
vue中v-for通过动态绑定class实现触发效果
2018/12/06 Javascript
layui table复选框禁止某几条勾选的实例
2019/09/20 Javascript
使用vue编写h5公众号跳转小程序的实现代码
2020/11/27 Vue.js
[03:47]2015国际邀请赛第三日现场精彩回顾
2015/08/08 DOTA
python生成指定长度的随机数密码
2014/01/23 Python
用python代码做configure文件
2014/07/20 Python
python模块之StringIO使用示例
2015/04/08 Python
Python数据分析:手把手教你用Pandas生成可视化图表的教程
2018/12/15 Python
Python比较配置文件的方法实例详解
2019/06/06 Python
Python Pandas中根据列的值选取多行数据
2019/07/08 Python
关于pandas的离散化,面元划分详解
2019/11/22 Python
Win系统PyQt5安装和使用教程
2019/12/25 Python
Android Q之气泡弹窗的实现示例
2020/06/23 Python
Python爬虫之Selenium下拉框处理的实现
2020/12/04 Python
分享一个页面平滑滚动小技巧(推荐)
2019/10/23 HTML / CSS
印度购物网站:TATA CLiQ
2017/11/23 全球购物
欧克利英国官网:Oakley英国
2019/08/24 全球购物
带薪年假请假条
2014/02/04 职场文书
公司中秋节活动方案
2014/02/12 职场文书
2014元旦晚会策划方案
2014/02/19 职场文书
Redis实战高并发之扣减库存项目
2022/04/14 Redis