pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
初步剖析C语言编程中的结构体
Jan 16 Python
Python 爬虫学习笔记之多线程爬虫
Sep 21 Python
Python实现的中国剩余定理算法示例
Aug 05 Python
Python中的is和==比较两个对象的两种方法
Sep 06 Python
详解Numpy中的数组拼接、合并操作(concatenate, append, stack, hstack, vstack, r_, c_等)
May 27 Python
Python笔记之facade模式
Nov 20 Python
Python绘图实现显示中文
Dec 04 Python
Python matplotlib画图时图例说明(legend)放到图像外侧详解
May 16 Python
keras Lambda自定义层实现数据的切片方式,Lambda传参数
Jun 11 Python
Python如何将字符串转换为日期
Jul 31 Python
python 如何实现遗传算法
Sep 22 Python
如何在 Matplotlib 中更改绘图背景的实现
Nov 26 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
如何使用PHP获取网络上文件
2006/10/09 PHP
一步一步学习PHP(4) php 函数 补充2
2010/02/15 PHP
PHP/HTML混写的四种方式总结
2017/02/27 PHP
Referer原理与图片防盗链实现方法详解
2019/07/03 PHP
对YUI扩展的Gird组件 Part-1
2007/03/10 Javascript
Function.prototype.call.apply结合用法分析示例
2013/07/03 Javascript
jquery ajax实现下拉框三级无刷新联动,且保存保持选中值状态
2013/10/29 Javascript
jQuery scrollFix滚动定位插件
2015/04/01 Javascript
通过隐藏iframe实现无刷新上传文件操作
2016/03/16 Javascript
vue项目优化之通过keep-alive数据缓存的方法
2017/12/11 Javascript
React学习笔记之高阶组件应用
2018/06/02 Javascript
javascript数组去重方法总结(推荐)
2019/03/20 Javascript
JavaScript中EventBus实现对象之间通信
2020/10/18 Javascript
用C++封装MySQL的API的教程
2015/05/06 Python
PyQt5每天必学之进度条效果
2018/04/19 Python
python 列表,数组和矩阵sum的用法及区别介绍
2018/06/28 Python
Python实现基于socket的udp传输与接收功能详解
2019/11/15 Python
python全局变量引用与修改过程解析
2020/01/07 Python
Python使用jpype模块调用jar包过程解析
2020/07/29 Python
Pandas直接读取sql脚本的方法
2021/01/21 Python
纯CSS3+DIV实现小三角形边框效果的示例代码
2020/08/03 HTML / CSS
用HTML5制作烟火效果的教程
2015/05/12 HTML / CSS
新百伦折扣店:Joe’s New Balance Outlet
2016/08/20 全球购物
俄语地区最大的中国商品在线购物网站之一:Umka Mall
2019/11/03 全球购物
垃圾回收的优点和原理。并考虑2种回收机制
2016/10/16 面试题
程序员跳槽必看面试题总结
2013/06/28 面试题
学院书画协会部门岗位职责
2013/12/01 职场文书
办公室综合文员岗位职责范本
2014/02/13 职场文书
王老吉广告词
2014/03/20 职场文书
入党积极分子学习优秀共产党员先进事迹思想汇报
2014/09/13 职场文书
教师个人查摆剖析材料
2014/10/14 职场文书
求职简历自我评价怎么写
2015/03/10 职场文书
总经理司机岗位职责
2015/04/10 职场文书
离婚起诉书怎么写
2015/05/19 职场文书
【海涛教你打DOTA】黑鸟第一视角解说
2022/04/01 DOTA
Mybatis-Plus 使用 @TableField 自动填充日期
2022/04/26 Java/Android