pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python获取CPU、内存和硬盘等windowns系统信息的2个例子
Apr 15 Python
Python socket C/S结构的聊天室应用实现
Nov 30 Python
python比较两个列表大小的方法
Jul 11 Python
python实现web方式logview的方法
Aug 10 Python
django 外键model的互相读取方法
Dec 15 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
如何使用Python实现斐波那契数列
Jul 02 Python
django之对FileField字段的upload_to的设定方法
Jul 28 Python
Python 50行爬虫抓取并处理图灵书目过程详解
Sep 20 Python
详解Python3定时器任务代码
Sep 23 Python
解决pycharm启动后总是不停的updating indices...indexing的问题
Nov 27 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
dedecms后台验证码总提示错误的解决方法
2007/03/21 PHP
初次接触php抽象工厂模式(Elgg)
2010/03/21 PHP
PHP+apc+ajax实现的ajax_upload上传进度条代码
2016/01/25 PHP
PHP中获取文件创建日期、修改日期、访问时间的方法
2016/11/05 PHP
jQuery 中关于CSS操作部分使用说明
2007/06/10 Javascript
整理一些JavaScript的IE和火狐的兼容性注意事项
2011/03/17 Javascript
jquery.bgiframe.js在IE9下提示INVALID_CHARACTER_ERR错误
2013/01/11 Javascript
jquery动态增加删除表格行的小例子
2013/11/14 Javascript
JavaScript中for循环的使用详解
2015/06/03 Javascript
jQuery遮罩层效果实例分析
2016/01/14 Javascript
基于jQuery实现数字滚动效果
2017/01/16 Javascript
React Native时间转换格式工具类分享
2017/10/24 Javascript
微信小程序实现定位及到指定位置导航的示例代码
2019/08/20 Javascript
基于ts的动态接口数据配置的详解
2019/12/18 Javascript
浅谈vue生命周期共有几个阶段?分别是什么?
2020/08/07 Javascript
[02:27]刀塔重生降临
2015/10/14 DOTA
[47:21]Liquid vs TNC Supermajor 胜者组 BO3 第一场 6.4
2018/06/05 DOTA
浅谈Python 集合(set)类型的操作——并交差
2016/06/30 Python
浅析python递归函数和河内塔问题
2017/04/18 Python
Python 线程池用法简单示例
2019/10/02 Python
Python程序暂停的正常处理方法
2019/11/07 Python
如何使用scrapy中的ItemLoader提取数据
2020/09/30 Python
CSS3中的opacity属性使用教程
2015/08/19 HTML / CSS
HTML5调用手机发短信和打电话功能
2020/04/29 HTML / CSS
德国最大的婴儿用品网上商店:Kidsroom.de(支持中文)
2020/09/02 全球购物
香港艺人陈冠希创办的潮流品牌:JUICESTORE
2021/03/04 全球购物
Linux如何压缩可执行文件
2013/10/21 面试题
会计专业导师推荐信
2014/03/08 职场文书
金融事务专业求职信
2014/04/25 职场文书
《三亚落日》教学反思
2014/04/26 职场文书
教师工作自我鉴定范文
2014/09/14 职场文书
公务员四风问题对照检查材料整改措施
2014/09/26 职场文书
汽车4S店销售经理岗位职责
2015/04/02 职场文书
运动会口号霸气押韵
2015/12/24 职场文书
一次线上mongo慢查询问题排查处理记录
2022/03/18 MongoDB
Spring Boot项目传参校验的最佳实践指南
2022/04/05 Java/Android