pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下函数参数的传递(参数带星号的说明)
Sep 19 Python
仅用50行代码实现一个Python编写的计算器的教程
Apr 17 Python
六个窍门助你提高Python运行效率
Jun 09 Python
Python自动化运维和部署项目工具Fabric使用实例
Sep 18 Python
浅谈Python实现2种文件复制的方法
Jan 19 Python
使用pandas模块读取csv文件和excel表格,并用matplotlib画图的方法
Jun 22 Python
在Mac上删除自己安装的Python方法
Oct 29 Python
Python实现二维曲线拟合的方法
Dec 29 Python
详解Django中CBV(Class Base Views)模型源码分析
Feb 25 Python
python的set处理二维数组转一维数组的方法示例
May 31 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
windows下的pycharm安装及其设置中文菜单
Apr 23 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
地球防卫队:陪着奥特曼打小怪兽的人类力量 那些经典队服
2020/03/08 日漫
新手配置 PHP 调试环境(IIS+PHP+MYSQL)
2007/01/10 PHP
php网站来路获取代码(针对搜索引擎)
2010/06/08 PHP
深入理解PHP原理之异常机制
2010/08/21 PHP
php更新mysql后获取影响的行数发生异常解决方法
2013/03/28 PHP
PHP的引用详解
2015/02/22 PHP
PHP基于递归算法解决兔子生兔子问题
2018/05/11 PHP
PHP rsa加密解密算法原理解析
2020/12/09 PHP
打造基于jQuery的高性能TreeView(asp.net)
2011/02/23 Javascript
10款非常有用的 Ajax 插件分享
2012/03/14 Javascript
JSON为什么那样红为什么要用json(另有洞天)
2012/12/26 Javascript
jQuery去掉字符串起始和结尾的空格(多种方法实现)
2013/04/01 Javascript
javaScript arguments 对象使用介绍
2013/10/18 Javascript
Javascript的&amp;&amp;和||的另类用法
2014/07/23 Javascript
Vue数据驱动模拟实现3
2017/01/11 Javascript
js仿淘宝商品放大预览功能
2017/03/15 Javascript
layui框架中layer父子页面交互的方法分析
2017/11/15 Javascript
详解AngularJS之$window窗口对象
2018/01/17 Javascript
React Native 自定义下拉刷新上拉加载的列表的示例
2018/03/01 Javascript
node中的session的具体使用
2018/09/14 Javascript
vue.draggable实现表格拖拽排序效果
2018/12/01 Javascript
JavaScript函数Call、Apply原理实例解析
2020/02/17 Javascript
[54:02]2018DOTA2亚洲邀请赛 4.1 小组赛 B组 IG vs VGJ.T
2018/04/03 DOTA
Pandas之MultiIndex对象的示例详解
2019/06/25 Python
Java文件与类动手动脑实例详解
2019/11/10 Python
Python更新所有已安装包的操作
2020/02/13 Python
英国旅游额外服务市场领导者:Holiday Extras(机场停车场、酒店、接送等)
2017/10/07 全球购物
澳大利亚最便宜的网上药房:Chemist Warehouse
2020/01/30 全球购物
作为网站管理者应当如何防范XSS
2014/08/16 面试题
大学校园毕业自我鉴定
2014/01/15 职场文书
客服部班长工作责任制
2014/02/25 职场文书
节约用水标语
2014/06/11 职场文书
2014年维修工作总结
2014/11/22 职场文书
功夫熊猫观后感
2015/06/10 职场文书
教师节晚会主持词
2015/06/30 职场文书
利用JuiceFS使MySQL 备份验证性能提升 10 倍
2022/03/17 MySQL