解决pytorch读取自制数据集出现过的问题


Posted in Python onMay 31, 2021

问题1

问题描述:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

解决方式

数据格式不对, 把image转成tensor,参数transform进行如下设置就可以了:transform=transform.ToTensor()。注意检测一下transform

问题2

问题描述:

TypeError: append() takes exactly one argument (2 given)

出现问题的地方

imgs.append(words[0], int(words[1]))

解决方式

加括号,如下

imgs.append((words[0], int(words[1])))

问题3

问题描述

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

解决方式

数据和模型不在同一设备上,应该要么都在GPU运行,要么都在CPU

问题4

问题描述

RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[1, 3, 512, 512] to have 1 channels, but got 3 channels instead

解决方式

图像竟然是RGB,但我的训练图像是一通道的灰度图,所以得想办法把 mode 转换成灰度图L

补充:神经网络 pytorch 数据集读取(自动读取数据集,手动读取自己的数据)

对于pytorch,我们有现成的包装好的数据集可以使用,也可以自己创建自己的数据集,大致来说有三种方法,这其中用到的两个包是datasets和DataLoader

datasets:用于将数据和标签打包成数据集

DataLoader:用于对数据集的高级处理,比如分组,打乱,处理等,在训练和测试中可以直接使用DataLoader进行处理

第一种 现成的打包数据集

这种比较简答,只需要现成的几行代码和一个路径就可以完成,但是一般都是常用比如cifar-10

解决pytorch读取自制数据集出现过的问题

对于常用数据集,可以使用torchvision.datasets直接进行读取,这是对其常用的处理,该类也是继承于torch.utils.data.Dataset。

#是第一次运行的话会下载数据集 现成的话可以使用root参数指定数据集位置
# 存放的格式如下图
 
# 根据接口读取默认的CIFAR10数据 进行训练和测试
#预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#读取数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
#打包成DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)
 
#同上
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=1)
classes = (1,2,3,4,5,6,7,8,9,10)  #类别定义
 
#使用
 for epoch in range(3):
        running_loss = 0.0 #清空loss
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
 
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
            
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

解决pytorch读取自制数据集出现过的问题

第二种 自己的图像分类

这也是一个方便的做法,在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。

要求:创建train和test文件夹,每个文件夹下按照类别名字存储图像就可以实现dataloader

这里还是拿上个举例子吧,实际上也可以是我们的数据集

解决pytorch读取自制数据集出现过的问题

每个下面的布局是这样的

解决pytorch读取自制数据集出现过的问题

# 预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
#使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
img_data = torchvision.datasets.ImageFolder('data/cifar-10/train/', transform=transform)
data_loader = torch.utils.data.DataLoader(img_data, batch_size=4, shuffle=True, num_workers=1)
 
testset = torchvision.datasets.ImageFolder('data/cifar-10/test/', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 
 for epoch in range(3):
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
 
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第三种 一维向量数据集

这个是比较尴尬的,首先我们

假设将数存储到txt等文件中,先把他读取出来,读取的部分就不仔细说了,读到一个列表里就可以

常用的可以是列表等,举例子

trainlist = []  # 保存特征的列表
 
targetpath = 'a/b/b'
filelist = os.listdir(targetpath) #列出文件夹下所有的目录与文件
filecount = len(filelist)
# 根据根路径 读取所有文件名 循环读取文件内容 添加到list
for i in range(filecount):
     filepath = os.path.join(targetpath, filelist[j])
     with open(filepath, 'r') as f:
         line = f.readline()
         # 例如存储格式为 1,2,3,4,5,6 数字之间以逗号隔开
         templist = list(map(int, line.split(',')))
         trainlist.append(templist)
 
# 数据读取完毕 现在为维度为filecount的列表 我们需要转换格式和类型
# 将数据转换为Tensor
 
# 假如我们的两类数据分别存在list0 和 list1中
split = len(list0) # 用于记录标签的分界
 
#使用numpy.array 和 torch.from_numpy 连续将其转换为tensor  使用torch.cat拼接
train0_numpy = numpy.array(list0)
train1_numpy = numpy.array(list1)
train_tensor = torch.cat([torch.from_numpy(train0_numpy), torch.from_numpytrain1_numpy)], 0)
#现在的尺寸是【样本数,长度】 然而在使用神 经网络处理一维数据要求【样本数,维度,长度】
# 这个维度指的像一个图像实际上是一个二维矩阵 但是有三个RGB通道 实际就为【3,行,列】 那么需要处理三个矩阵
# 我们需要在我们的数据中加上这个维度信息
# 注意类型要一样 可以转换
shaper = train_tensor.shape  #获取维度 【样本数,长度】
aa = torch.ones((shaper[0], 1, shaper[1])) # 生成目标矩阵
for i in range(shaper[0]):  # 将所有样本复制到新矩阵
·    aa[i][0][:] = train_tensor[i][:]
train_tensor = aa  # 完成了数据集的转换 【样本数,维度,长度】
 
# 注 意 如果是读取的图像 我们需要的目标维度是【样本数,维度,size_w,size_h】
# 卷积接受的输入是这样的四维度 最后的两个是图像的尺寸 维度表示是通道数量 
  
# 下面是生成标签 标签注意类别之间的分界 split已经在上文计算出来
# 训练标签的
total = len(list0) + len(list1)
train_label = numpy.zeros(total)
train_label[split+1:total] = 1
train_label_tensor = torch.from_numpy(train_label).int()
# print(train_tensor.size(),train_label_tensor.size())
 
# 搭建dataloader完毕
train_dataset = TensorDataset(train_tensor, train_label_tensor)
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
 
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data #trainloader返回:id,image,labels
        # 将inputs与labels装进Variable中
        inputs, labels = Variable(inputs), Variable(labels)
 
        #使用print代替输出
        print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第四种 保存路径和标签的方式创建数据集

该方法需要略微的麻烦一些,首先你有一个txt,保存了文件名和对应的标签,大概是这个意思

解决pytorch读取自制数据集出现过的问题

然后我们在程序中,根据给定的根目录找到文件,并将标签对应保存

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

这是dataset的原本内容,getitem就是获取元素的部分,用于返回对应index的数据和标签。那么大概需要做的是我们将txt的内容读取进来,使用程序处理标签和数据

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 初始化读取txt 可以设定变换
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
         # 保存列表 其中有图像的数据 和标签
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
    # 返回图像和标签
    
	return img, label
def __len__(self):
	return len(self.imgs)
 
# 当然也可以创建myImageFloder 其txt格式在下图显示 
import os
import torch
import torch.utils.data as data
from PIL import Image 
def default_loader(path):
    return Image.open(path).convert('RGB')
 
class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label) #打开label文件
        c=0
        imgs=[]  # 保存图像的列表
        class_names=[]
        for line in  fh.readlines(): #读取每一行数据
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')] 
            else:
                cls = line.split() #分割为列表
                fn = cls.pop(0)  #弹出最上的一个
                if os.path.isfile(os.path.join(root, fn)):  # 组合路径名 读取图像
                    imgs.append((fn, tuple([float(v) for v in cls])))  #添加到列表
            c=c+1
 
        # 设置信息
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
 
    def __getitem__(self, index):  # 获取图像 给定序号
        fn, label = self.imgs[index]  #读取图像的内容和对应的label
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:  # 是否变换
            img = self.transform(img)
        return img, torch.Tensor(label) # 返回图像和label
 
    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes
#

解决pytorch读取自制数据集出现过的问题

# 而后使用的时候就可以正常的使用
trainset = MyDataset(txt_path=pathFile,transform = None, target_transform = None)
# trainset = torch.utils.data.DataLoader(myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), batch_size= 2, shuffle= False, num_workers= 2)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)

它的要点是,继承dataset,在初始化中处理txt文本数据,保存对应的数据,并实现对应的功能。

这其中的原理就是如此,但是注意可能有些许略微不恰当的地方,可能就需要到时候现场调试了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
布同 Python中文问题解决方法(总结了多位前人经验,初学者必看)
Mar 13 Python
python登陆asp网站页面的实现代码
Jan 14 Python
用Python设计一个经典小游戏
May 15 Python
python中requests爬去网页内容出现乱码问题解决方法介绍
Oct 25 Python
django实现登录时候输入密码错误5次锁定用户十分钟
Nov 05 Python
如何在python中使用selenium的示例
Dec 26 Python
python实现自动登录后台管理系统
Oct 18 Python
Python实现简易过滤删除数字的方法小结
Jan 09 Python
对python_discover方法遍历所有执行的用例详解
Feb 13 Python
python实例化对象的具体方法
Jun 17 Python
Python txt文件常用读写操作代码实例
Aug 03 Python
聊聊python在linux下与windows下导入模块的区别说明
Mar 03 Python
Python爬虫基础初探selenium
只用40行Python代码就能写出pdf转word小工具
pytorch 如何把图像数据集进行划分成train,test和val
May 31 #Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
You might like
PHP 模拟登陆MSN并获得用户信息
2009/05/16 PHP
Php获取金书网的书名的实现代码
2010/06/11 PHP
基于CakePHP实现的简单博客系统实例
2015/06/28 PHP
php PDO异常处理详解
2016/11/20 PHP
JavaScript表单常用验证集合
2008/01/16 Javascript
Jquery选择器中使用变量实现动态选择例子
2014/07/25 Javascript
原生Javascript封装的一个AJAX函数分享
2014/10/11 Javascript
jquery实现一个简单的表单验证实例
2016/03/30 Javascript
使用get方式提交表单在地址栏里面不显示提交信息
2017/02/21 Javascript
浅析Angular 实现一个repeat指令的方法
2019/07/21 Javascript
JS实现简单随机3D骰子
2019/10/24 Javascript
微信小程序实现签字功能
2019/12/23 Javascript
24个解决实际问题的ES6代码片段(小结)
2020/02/02 Javascript
详解vuejs中执行npm run dev出现页面cannot GET/问题
2020/04/26 Javascript
vue+AI智能机器人回复功能实现
2020/07/16 Javascript
JavaScript缓动动画函数的封装方法
2020/11/25 Javascript
[01:05:52]DOTA2-DPC中国联赛 正赛 Ehome vs Aster BO3 第一场 2月2日
2021/03/11 DOTA
Python的字典和列表的使用中一些需要注意的地方
2015/04/24 Python
Python调用SQLPlus来操作和解析Oracle数据库的方法
2016/04/09 Python
python实现批量监控网站
2016/09/09 Python
一个月入门Python爬虫学习,轻松爬取大规模数据
2018/01/03 Python
python 把文件中的每一行以数组的元素放入数组中的方法
2018/04/29 Python
python pandas实现excel转为html格式的方法
2018/10/23 Python
解决Pycharm下面出现No R interpreter defined的问题
2018/10/29 Python
Python中的self用法详解
2019/08/06 Python
python基于K-means聚类算法的图像分割
2019/10/30 Python
Python imutils 填充图片周边为黑色的实现
2020/01/19 Python
django模型动态修改参数,增加 filter 字段的方式
2020/03/16 Python
CSS3教程:新增加的结构伪类
2009/04/02 HTML / CSS
高山背包:High Sierra
2017/11/23 全球购物
中国跨境电子商务网站:NewFrog
2018/03/10 全球购物
营业员个人总结的自我评价
2013/10/25 职场文书
公务员职业生涯规划书范文  
2014/01/19 职场文书
个人欠款担保书
2014/05/20 职场文书
身边的榜样活动方案
2014/08/20 职场文书
python3 实现mysql数据库连接池的示例代码
2021/04/17 Python