解决pytorch读取自制数据集出现过的问题


Posted in Python onMay 31, 2021

问题1

问题描述:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

解决方式

数据格式不对, 把image转成tensor,参数transform进行如下设置就可以了:transform=transform.ToTensor()。注意检测一下transform

问题2

问题描述:

TypeError: append() takes exactly one argument (2 given)

出现问题的地方

imgs.append(words[0], int(words[1]))

解决方式

加括号,如下

imgs.append((words[0], int(words[1])))

问题3

问题描述

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

解决方式

数据和模型不在同一设备上,应该要么都在GPU运行,要么都在CPU

问题4

问题描述

RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[1, 3, 512, 512] to have 1 channels, but got 3 channels instead

解决方式

图像竟然是RGB,但我的训练图像是一通道的灰度图,所以得想办法把 mode 转换成灰度图L

补充:神经网络 pytorch 数据集读取(自动读取数据集,手动读取自己的数据)

对于pytorch,我们有现成的包装好的数据集可以使用,也可以自己创建自己的数据集,大致来说有三种方法,这其中用到的两个包是datasets和DataLoader

datasets:用于将数据和标签打包成数据集

DataLoader:用于对数据集的高级处理,比如分组,打乱,处理等,在训练和测试中可以直接使用DataLoader进行处理

第一种 现成的打包数据集

这种比较简答,只需要现成的几行代码和一个路径就可以完成,但是一般都是常用比如cifar-10

解决pytorch读取自制数据集出现过的问题

对于常用数据集,可以使用torchvision.datasets直接进行读取,这是对其常用的处理,该类也是继承于torch.utils.data.Dataset。

#是第一次运行的话会下载数据集 现成的话可以使用root参数指定数据集位置
# 存放的格式如下图
 
# 根据接口读取默认的CIFAR10数据 进行训练和测试
#预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#读取数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
#打包成DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)
 
#同上
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=1)
classes = (1,2,3,4,5,6,7,8,9,10)  #类别定义
 
#使用
 for epoch in range(3):
        running_loss = 0.0 #清空loss
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
 
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
            
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

解决pytorch读取自制数据集出现过的问题

第二种 自己的图像分类

这也是一个方便的做法,在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。

要求:创建train和test文件夹,每个文件夹下按照类别名字存储图像就可以实现dataloader

这里还是拿上个举例子吧,实际上也可以是我们的数据集

解决pytorch读取自制数据集出现过的问题

每个下面的布局是这样的

解决pytorch读取自制数据集出现过的问题

# 预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
#使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
img_data = torchvision.datasets.ImageFolder('data/cifar-10/train/', transform=transform)
data_loader = torch.utils.data.DataLoader(img_data, batch_size=4, shuffle=True, num_workers=1)
 
testset = torchvision.datasets.ImageFolder('data/cifar-10/test/', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 
 for epoch in range(3):
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
 
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第三种 一维向量数据集

这个是比较尴尬的,首先我们

假设将数存储到txt等文件中,先把他读取出来,读取的部分就不仔细说了,读到一个列表里就可以

常用的可以是列表等,举例子

trainlist = []  # 保存特征的列表
 
targetpath = 'a/b/b'
filelist = os.listdir(targetpath) #列出文件夹下所有的目录与文件
filecount = len(filelist)
# 根据根路径 读取所有文件名 循环读取文件内容 添加到list
for i in range(filecount):
     filepath = os.path.join(targetpath, filelist[j])
     with open(filepath, 'r') as f:
         line = f.readline()
         # 例如存储格式为 1,2,3,4,5,6 数字之间以逗号隔开
         templist = list(map(int, line.split(',')))
         trainlist.append(templist)
 
# 数据读取完毕 现在为维度为filecount的列表 我们需要转换格式和类型
# 将数据转换为Tensor
 
# 假如我们的两类数据分别存在list0 和 list1中
split = len(list0) # 用于记录标签的分界
 
#使用numpy.array 和 torch.from_numpy 连续将其转换为tensor  使用torch.cat拼接
train0_numpy = numpy.array(list0)
train1_numpy = numpy.array(list1)
train_tensor = torch.cat([torch.from_numpy(train0_numpy), torch.from_numpytrain1_numpy)], 0)
#现在的尺寸是【样本数,长度】 然而在使用神 经网络处理一维数据要求【样本数,维度,长度】
# 这个维度指的像一个图像实际上是一个二维矩阵 但是有三个RGB通道 实际就为【3,行,列】 那么需要处理三个矩阵
# 我们需要在我们的数据中加上这个维度信息
# 注意类型要一样 可以转换
shaper = train_tensor.shape  #获取维度 【样本数,长度】
aa = torch.ones((shaper[0], 1, shaper[1])) # 生成目标矩阵
for i in range(shaper[0]):  # 将所有样本复制到新矩阵
·    aa[i][0][:] = train_tensor[i][:]
train_tensor = aa  # 完成了数据集的转换 【样本数,维度,长度】
 
# 注 意 如果是读取的图像 我们需要的目标维度是【样本数,维度,size_w,size_h】
# 卷积接受的输入是这样的四维度 最后的两个是图像的尺寸 维度表示是通道数量 
  
# 下面是生成标签 标签注意类别之间的分界 split已经在上文计算出来
# 训练标签的
total = len(list0) + len(list1)
train_label = numpy.zeros(total)
train_label[split+1:total] = 1
train_label_tensor = torch.from_numpy(train_label).int()
# print(train_tensor.size(),train_label_tensor.size())
 
# 搭建dataloader完毕
train_dataset = TensorDataset(train_tensor, train_label_tensor)
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
 
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data #trainloader返回:id,image,labels
        # 将inputs与labels装进Variable中
        inputs, labels = Variable(inputs), Variable(labels)
 
        #使用print代替输出
        print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第四种 保存路径和标签的方式创建数据集

该方法需要略微的麻烦一些,首先你有一个txt,保存了文件名和对应的标签,大概是这个意思

解决pytorch读取自制数据集出现过的问题

然后我们在程序中,根据给定的根目录找到文件,并将标签对应保存

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

这是dataset的原本内容,getitem就是获取元素的部分,用于返回对应index的数据和标签。那么大概需要做的是我们将txt的内容读取进来,使用程序处理标签和数据

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 初始化读取txt 可以设定变换
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
         # 保存列表 其中有图像的数据 和标签
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
    # 返回图像和标签
    
	return img, label
def __len__(self):
	return len(self.imgs)
 
# 当然也可以创建myImageFloder 其txt格式在下图显示 
import os
import torch
import torch.utils.data as data
from PIL import Image 
def default_loader(path):
    return Image.open(path).convert('RGB')
 
class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label) #打开label文件
        c=0
        imgs=[]  # 保存图像的列表
        class_names=[]
        for line in  fh.readlines(): #读取每一行数据
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')] 
            else:
                cls = line.split() #分割为列表
                fn = cls.pop(0)  #弹出最上的一个
                if os.path.isfile(os.path.join(root, fn)):  # 组合路径名 读取图像
                    imgs.append((fn, tuple([float(v) for v in cls])))  #添加到列表
            c=c+1
 
        # 设置信息
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
 
    def __getitem__(self, index):  # 获取图像 给定序号
        fn, label = self.imgs[index]  #读取图像的内容和对应的label
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:  # 是否变换
            img = self.transform(img)
        return img, torch.Tensor(label) # 返回图像和label
 
    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes
#

解决pytorch读取自制数据集出现过的问题

# 而后使用的时候就可以正常的使用
trainset = MyDataset(txt_path=pathFile,transform = None, target_transform = None)
# trainset = torch.utils.data.DataLoader(myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), batch_size= 2, shuffle= False, num_workers= 2)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)

它的要点是,继承dataset,在初始化中处理txt文本数据,保存对应的数据,并实现对应的功能。

这其中的原理就是如此,但是注意可能有些许略微不恰当的地方,可能就需要到时候现场调试了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python函数学习笔记
Oct 07 Python
python类和继承用法实例
Jul 07 Python
Python正规则表达式学习指南
Aug 02 Python
python数据封装json格式数据
Mar 04 Python
Python打开文件,将list、numpy数组内容写入txt文件中的方法
Oct 26 Python
Django如何实现上传图片功能
Aug 16 Python
Python类继承和多态原理解析
Feb 05 Python
TensorFlow实现从txt文件读取数据
Feb 05 Python
Python接口测试get请求过程详解
Feb 28 Python
python实现文法左递归的消除方法
May 22 Python
Python 如何创建一个线程池
Jul 28 Python
详解pandas映射与数据转换
Jan 22 Python
Python爬虫基础初探selenium
只用40行Python代码就能写出pdf转word小工具
pytorch 如何把图像数据集进行划分成train,test和val
May 31 #Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
You might like
PHP 最大运行时间 max_execution_time修改方法
2010/03/08 PHP
PHP基于ORM方式操作MySQL数据库实例
2017/06/21 PHP
jquery分页插件jpaginate在IE中不兼容问题
2014/04/22 Javascript
iframe如何动态创建及释放其所占内存
2014/09/03 Javascript
javascript变量声明实例分析
2015/04/25 Javascript
基于jQuery实现收缩展开功能
2016/03/18 Javascript
BootstrapTable请求数据时设置超时(timeout)的方法
2017/01/22 Javascript
canvas 绘制圆形时钟
2017/02/22 Javascript
Bootstrap Table使用整理(四)之工具栏
2017/06/09 Javascript
Vue.js项目部署到服务器的详细步骤
2017/07/17 Javascript
vue生成随机验证码的示例代码
2017/09/29 Javascript
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
JavaScript数组去重算法实例小结
2018/05/07 Javascript
vue-cli项目根据线上环境分别打出测试包和生产包
2018/05/23 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
微信小程序列表时间戳转换实现过程解析
2019/10/12 Javascript
详解小程序如何改变onLoad的执行时机
2019/11/01 Javascript
vue使用lodop打印控件实现浏览器兼容打印的方法
2021/02/07 Vue.js
[02:52]2017DOTA2国际邀请赛中国区预选赛晋级之路
2017/07/03 DOTA
Python不规范的日期字符串处理类
2014/06/10 Python
基于Django模板中的数字自增(详解)
2017/09/05 Python
Django rest framework工具包简单用法示例
2018/07/20 Python
python删除字符串中指定字符的方法
2018/08/13 Python
如何利用python制作时间戳转换工具详解
2018/09/12 Python
Python基于plotly模块实现的画图操作示例
2019/01/23 Python
Django 简单实现分页与搜索功能的示例代码
2019/11/07 Python
Python面向对象封装操作案例详解
2019/12/31 Python
Python Dataframe常见索引方式详解
2020/05/27 Python
如何用Python 加密文件
2020/09/10 Python
新西兰最大的在线设计师眼镜店:SmartBuyGlasses新西兰
2017/10/20 全球购物
古驰英国官网:GUCCI英国
2020/03/07 全球购物
高中毕业生自我鉴定
2013/11/03 职场文书
活动总结书怎么写
2015/05/11 职场文书
2019销售早会主持词
2019/06/27 职场文书
2019大学生社会实践报告汇总
2019/08/16 职场文书
Python实现信息管理系统
2022/06/05 Python