解决Pytorch dataloader时报错每个tensor维度不一样的问题


Posted in Python onMay 28, 2021

使用pytorch的dataloader报错:

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

1. 问题描述

报错定位:位于定义dataset的代码中

def __getitem__(self, index):
 ...
 return y    #此处报错

报错内容

File "D:\python\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

把前一行的报错带上能够更清楚地明白问题在哪里.

2.问题分析

从报错可以看到,是在代码中执行torch.stack时发生了报错.因此必须要明白在哪里执行了stack操作.

通过调试可以发现,在通过loader加载一个batch数据的时候,是通过每一次给一个随机的index取出相应的向量.那么最终要形成一个batch的数据就必须要进行拼接操作,而torch.stack就是进行这里所说的拼接.

再来看看具体报的什么错: 说是stack的向量维度不同. 这说明在每次给出一个随机的index,返回的y向量的维度应该是相同的,而我们这里是不同的.

这样解决方法也就明确了:使返回的向量y的维度固定下来.

3.问题出处

为什么我会出现这样的一个问题,是因为我的特征向量中存在multi-hot特征.而为了节省空间,我是用一个列表存储这个特征的.示例如下:

feature=[[1,3,5],
  [0,2],
  [1,2,5,8]]

这就导致了我每次返回的向量的维度是不同的.因此可以采用向量补全的方法,把不同长度的向量补全成等长的.

# 把所有向量的长度都补为6
 multi = np.pad(multi, (0, 6-multi.shape[0]), 'constant', constant_values=(0, -1))

4.总结

在构建dataset重写的__getitem__方法中要返回相同长度的tensor.

可以使用向量补全的方法来解决这个问题.

补充:pytorch学习笔记:torch.utils.data下的TensorDataset和DataLoader的使用

一、TensorDataset

对给定的tensor数据(样本和标签),将它们包装成dataset。注意,如果是numpy的array,或者Pandas的DataFrame需要先转换成Tensor。

'''
data_tensor (Tensor) - 样本数据
target_tensor (Tensor) - 样本目标(标签)
'''
 dataset=torch.utils.data.TensorDataset(data_tensor, 
                                        target_tensor)

下面举个例子:

我们先定义一下样本数据和标签数据,一共有1000个样本

import torch
import numpy as np
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.tensor(np.random.normal(0, 1, 
                       (num_examples, num_inputs)), 
                       dtype=torch.float)

labels = true_w[0] * features[:, 0] + \
         true_w[1] * features[:, 1] + true_b

labels += torch.tensor(np.random.normal(0, 0.01, 
                       size=labels.size()), 
                       dtype=torch.float)

print(features.shape)
print(labels.shape)

'''
输出:torch.Size([1000, 2])
     torch.Size([1000])
'''

然后我们使用TensorDataset来生成数据集

import torch.utils.data as Data
# 将训练数据的特征和标签组合
dataset = Data.TensorDataset(features, labels)

二、DataLoader

数据加载器,组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。它可以对我们上面所说的数据集Dataset作进一步的设置。

dataset (Dataset) ? 加载数据的数据集。

batch_size (int, optional) ? 每个batch加载多少个样本(默认: 1)。

shuffle (bool, optional) ? 设置为True时会在每个epoch重新打乱数据(默认: False).

sampler (Sampler, optional) ? 定义从数据集中提取样本的策略。如果指定,则shuffle必须设置成False。

num_workers (int, optional) ? 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

pin_memory:内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

drop_last (bool, optional) ? 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

timeout:是用来设置数据读取的超时时间的,如果超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

data_iter=torch.utils.data.DataLoader(dataset, batch_size=1, 
                            shuffle=False, sampler=None, 
                            batch_sampler=None, num_workers=0, 
                            collate_fn=None, pin_memory=False, 
                            drop_last=False, timeout=0, 
                            worker_init_fn=None, 
                            multiprocessing_context=None)

上面对一些重要常用的参数做了说明,其中有一个参数是sampler,下面我们对它有哪些具体取值再做一下说明。只列出几个常用的取值:

torch.utils.data.sampler.SequentialSampler(dataset)

样本元素按顺序采样,始终以相同的顺序。

torch.utils.data.sampler.RandomSampler(dataset)

样本元素随机采样,没有替换。

torch.utils.data.sampler.SubsetRandomSampler(indices)

样本元素从指定的索引列表中随机抽取,没有替换。

下面就来看一个例子,该例子使用的dataset就是上面所生成的dataset

data_iter=Data.DataLoader(dataset, 
                          batch_size=10, 
                          shuffle=False,
sampler=torch.utils.data.sampler.RandomSampler(dataset))

for X, y in data_iter:
    print(X,"\n", y)
    break

'''
输出:
tensor([[-1.6338,  0.8451],
        [ 0.7245, -0.7387],
        [ 0.4672,  0.2623],
        [-1.9082,  0.0980],
        [-0.3881,  0.5138],
        [-0.6983, -0.4712],
        [ 0.1400,  0.7489],
        [-0.7761, -0.4596],
        [-2.2700, -0.2532],
        [-1.2641, -2.8089]]) 

tensor([-1.9451,  8.1587,  4.2374,  0.0519,  1.6843,  4.3970,  
        1.9311,  4.1999,0.5253, 11.2277])
'''

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的tuple元组详细介绍
Feb 02 Python
Win10下Python环境搭建与配置教程
Nov 18 Python
使用Python写一个贪吃蛇游戏实例代码
Aug 21 Python
python GUI实例学习
Nov 21 Python
python自动登录12306并自动点击验证码完成登录的实现源代码
Apr 25 Python
Python3 中把txt数据文件读入到矩阵中的方法
Apr 27 Python
pygame游戏之旅 载入小车图片、更新窗口
Nov 20 Python
浅谈Pandas Series 和 Numpy array中的相同点
Jun 28 Python
Python.append()与Python.expand()用法详解
Dec 18 Python
python openCV实现摄像头获取人脸图片
Aug 20 Python
如何用python识别滑块验证码中的缺口
Apr 01 Python
Pandas实现批量拆分与合并Excel的示例代码
May 30 Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
You might like
php获取url字符串截取路径的文件名和扩展名的函数
2010/01/22 PHP
同台服务器使用缓存APC效率高于Memcached的演示代码
2010/02/16 PHP
php生成excel列名超过26列大于Z时的解决方法
2014/12/29 PHP
PHP实现微信商户支付企业付款到零钱功能
2018/09/30 PHP
基于jQuery的星级评分插件
2011/08/12 Javascript
不同的jQuery API来处理不同的浏览器事件
2012/12/09 Javascript
前端开发过程中浏览器版本的两种判定方法
2013/10/30 Javascript
12306 刷票脚本及稳固刷票脚本(防挂)
2017/01/04 Javascript
Angular2.0实现modal对话框的方法示例
2018/02/18 Javascript
vue路由前进后退动画效果的实现代码
2018/12/10 Javascript
vue实现form表单与table表格的数据关联功能示例
2019/01/29 Javascript
使用jquery的cookie实现登录页记住用户名和密码的方法
2019/03/13 jQuery
js控制随机数生成概率代码实例
2019/03/21 Javascript
微信小程序 setData 对 data数据影响问题
2019/04/18 Javascript
js打开word文档预览操作示例【不是下载】
2019/05/23 Javascript
vue 源码解析之虚拟Dom-render
2019/08/26 Javascript
对layui中table组件工具栏的使用详解
2019/09/19 Javascript
Vue项目中使用flow做类型检测的方法
2020/03/18 Javascript
详解ES6数组方法find()、findIndex()的总结
2020/05/12 Javascript
详解微信小程序入门从这里出发(登录注册、开发工具、文件及结构介绍)
2020/07/21 Javascript
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
python使用cPickle模块序列化实例
2014/09/25 Python
浅谈tensorflow中几个随机函数的用法
2018/07/27 Python
django开发post接口简单案例,获取参数值的方法
2018/12/11 Python
python 用下标截取字符串的实例
2018/12/25 Python
说说如何遍历Python列表的方法示例
2019/02/11 Python
django ORM之values和annotate使用详解
2020/05/19 Python
python字符串的index和find的区别详解
2020/06/20 Python
python用Tkinter做自己的中文代码编辑器
2020/09/07 Python
python工具——Mimesis的简单使用教程
2021/01/16 Python
加拿大健康、婴儿和美容产品在线购物:Well.ca
2016/11/30 全球购物
教育科学研究生自荐信
2013/10/09 职场文书
大学生入党自荐书
2015/03/05 职场文书
抢劫罪辩护词
2015/05/21 职场文书
2019年干货:自我鉴定
2019/03/25 职场文书
Java循环队列与非循环队列的区别总结
2021/06/22 Java/Android