pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中while循环语句用法简单实例
May 07 Python
python通过apply使用元祖和列表调用函数实例
May 26 Python
python中的代码编码格式转换问题
Jun 10 Python
Python选课系统开发程序
Sep 02 Python
python使用super()出现错误解决办法
Aug 14 Python
python样条插值的实现代码
Dec 17 Python
解决pycharm工程启动卡住没反应的问题
Jan 19 Python
python 去除二维数组/二维列表中的重复行方法
Jan 23 Python
selenium+python截图不成功的解决方法
Jan 30 Python
Python zip函数打包元素实例解析
Dec 11 Python
降低python版本的操作方法
Sep 11 Python
virtualenv隔离Python环境的问题解析
Jun 21 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
[原创]PHP中通过ADODB库实现调用Access数据库之修正版本
2006/12/31 PHP
php运行出现Call to undefined function curl_init()的解决方法
2010/11/02 PHP
使用NetBeans + Xdebug调试PHP程序的方法
2011/04/12 PHP
基于php中使用excel的简单介绍
2013/08/02 PHP
php使用filter过滤器验证邮箱 ipv6地址 url验证
2013/12/25 PHP
PHP微信支付实例解析
2016/07/22 PHP
PHP加密技术的简单实现
2016/09/04 PHP
Prototype ObjectRange对象学习
2009/07/19 Javascript
显示js对象所有属性和方法的函数
2009/10/16 Javascript
基于Asp.net与Javascript控制的日期控件
2010/05/22 Javascript
JS跨域总结
2012/08/30 Javascript
基于KMP算法JavaScript的实现方法分析
2013/05/03 Javascript
jQuery在html有效在jsp无效的原因及解决方法
2013/08/02 Javascript
jQuery的显示和隐藏方法与css隐藏的样式对比
2013/10/18 Javascript
jQuery学习笔记之jQuery构建函数的7种方法
2014/06/03 Javascript
浅谈JavaScript中的字符编码转换问题
2015/07/07 Javascript
解决jquery无法找到其他父级子集问题的方法
2016/05/10 Javascript
jQuery页面元素动态添加后绑定事件丢失方法,非 live
2016/06/16 Javascript
详解基于webpack2.x的vue2.x的多页面站点
2017/08/21 Javascript
JavaScript实现封闭区域布尔运算的示例代码
2018/06/25 Javascript
Vue 路由间跳转和新开窗口的方式(query、params)
2019/12/25 Javascript
vue使用微信扫一扫功能的实现代码
2020/04/11 Javascript
python在html中插入简单的代码并加上时间戳的方法
2018/10/16 Python
python 实现图片批量压缩的示例
2020/12/18 Python
CSS3 clip-path 用法介绍详解
2018/03/01 HTML / CSS
如何查看在weblogic中已经发布的EJB
2012/06/01 面试题
Java里面Pass by value和Pass by Reference是什么意思
2016/05/02 面试题
给儿子的表扬信
2014/01/15 职场文书
2014高考励志标语
2014/06/05 职场文书
机械电子工程专业自荐书
2014/06/10 职场文书
医院领导班子整改方案
2014/10/01 职场文书
一般基层干部群众路线教育实践活动个人对照检查材料
2014/11/04 职场文书
刘公岛导游词
2015/02/05 职场文书
求职推荐信范文
2015/03/27 职场文书
2015年教研室工作总结范文
2015/05/23 职场文书
2019安全宣传标语大全
2019/08/14 职场文书