pytorch加载自己的图像数据集实例


Posted in Python onJuly 07, 2020

之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码。到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题。

参考文章https://3water.com/article/177613.htm

下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor。最后读取第一张图片并显示。

# 数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

transform = transforms.Compose([
 transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
 # transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])

#定义自己的数据集合
class FlameSet(data.Dataset):
 def __init__(self,root):
  # 所有图片的绝对路径
  imgs=os.listdir(root)
  self.imgs=[os.path.join(root,k) for k in imgs]
  self.transforms=transform

 def __getitem__(self, index):
  img_path = self.imgs[index]
  pil_img = Image.open(img_path)
  if self.transforms:
   data = self.transforms(pil_img)
  else:
   pil_img = np.asarray(pil_img)
   data = torch.from_numpy(pil_img)
  return data

 def __len__(self):
  return len(self.imgs)

if __name__ == '__main__':
 dataSet=FlameSet('./test')
 print(dataSet[0])

显示结果:

pytorch加载自己的图像数据集实例

补充知识:使用Pytorch进行读取本地的MINIST数据集并进行装载

pytorch中的torchvision.datasets中自带MINIST数据集,可直接调用模块进行获取,也可以进行自定义自己的Dataset类进行读取本地数据和初始化数据。

1. 直接使用pytorch自带的MNIST进行下载:

缺点: 下载速度较慢,而且如果中途下载失败一般得是重新进行执行代码进行下载:

# # 训练数据和测试数据的下载
# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
  root="./data", # 下载数据,并且存放在data文件夹中
  train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
  transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
  download=True # 瞎子啊过程中如果中断,或者下载完成之后再次运行,则会出现报错
)

testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

2. 自定义dataset类进行数据的读取以及初始化。

其中自己下载的MINIST数据集的内容如下:

pytorch加载自己的图像数据集实例

自己定义的dataset类需要继承: Dataset

需要实现必要的魔法方法:

__init__魔法方法里面进行读取数据文件

__getitem__魔法方法进行支持下标访问

__len__魔法方法返回自定义数据集的大小,方便后期遍历

示例如下:

class DealDataset(Dataset):
  """
    读取数据、初始化数据
  """
  def __init__(self, folder, data_name, label_name,transform=None):
    (train_set, train_labels) = load_minist_data.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
    self.train_set = train_set
    self.train_labels = train_labels
    self.transform = transform

  def __getitem__(self, index):

    img, target = self.train_set[index], int(self.train_labels[index])
    if self.transform is not None:
      img = self.transform(img)
    return img, target

  def __len__(self):
    return len(self.train_set)

其中load_minist_data.load_data也是我们自己写的读取数据文件的函数,即放在了load_minist_data.py中的load_data函数中。具体实现如下:

def load_data(data_folder, data_name, label_name):
 """
    data_folder: 文件目录
    data_name: 数据文件名
    label_name:标签数据文件名
  """
 with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
  y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

 with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
  x_train = np.frombuffer(
    imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
 return (x_train, y_train)

编写完自定义的dataset就可以进行实例化该类并装载数据:

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset('MNIST_data/', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
testDataset = DealDataset('MNIST_data/', "t10k-images-idx3-ubyte.gz","t10k-labels-idx1-ubyte.gz",transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = dataloader.DataLoader(
  dataset=trainDataset,
  batch_size=100, # 一个批次可以认为是一个包,每个包中含有100张图片
  shuffle=False,
)

test_loader = dataloader.DataLoader(
  dataset=testDataset,
  batch_size=100,
  shuffle=False,
)

构建简单的神经网络并进行训练和测试:

class NeuralNet(nn.Module):

  def __init__(self, input_num, hidden_num, output_num):
    super(NeuralNet, self).__init__()
    self.fc1 = nn.Linear(input_num, hidden_num)
    self.fc2 = nn.Linear(hidden_num, output_num)
    self.relu = nn.ReLU()

  def forward(self,x):
    x = self.fc1(x)
    x = self.relu(x)
    y = self.fc2(x)
    return y

# 参数初始化
epoches = 5
lr = 0.001
input_num = 784
hidden_num = 500
output_num = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 产生训练模型对象以及定义损失函数和优化函数
model = NeuralNet(input_num, hidden_num, output_num)
model.to(device)
criterion = nn.CrossEntropyLoss() # 使用交叉熵作为损失函数
optimizer = optim.Adam(model.parameters(), lr=lr)

# 开始循环训练
for epoch in range(epoches): # 一个epoch可以认为是一次训练循环
  for i, data in enumerate(train_loader):
    (images, labels) = data
    images = images.reshape(-1, 28*28).to(device)
    labels = labels.to(device)
    output = model(images) # 经过模型对象就产生了输出
    loss = criterion(output, labels.long()) # 传入的参数: 输出值(预测值), 实际值(标签)
    optimizer.zero_grad() # 梯度清零
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0: # i表示样本的编号
      print('Epoch [{}/{}], Loss: {:.4f}'
         .format(epoch + 1, epoches, loss.item())) # {}里面是后面需要传入的变量
                              # loss.item
# 开始测试
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, 28*28).to(device) # 此处的-1一般是指自动匹配的意思, 即不知道有多少行,但是确定了列数为28 * 28
                           # 其实由于此处28 * 28本身就已经等于了原tensor的大小,所以,行数也就确定了,为1
    labels = labels.to(device)
    output = model(images)
    _, predicted = torch.max(output, 1)
    total += labels.size(0) # 此处的size()类似numpy的shape: np.shape(train_images)[0]
    correct += (predicted == labels).sum().item()
  print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

以上这篇pytorch加载自己的图像数据集实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之对象和类的实际运用
Aug 29 Python
python中文编码问题小结
Sep 28 Python
python采用getopt解析命令行输入参数实例
Sep 30 Python
Python入门篇之对象类型
Oct 17 Python
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
Python 实现数据库(SQL)更新脚本的生成方法
Jul 09 Python
python实现rsa加密实例详解
Jul 19 Python
python实现维吉尼亚算法
Mar 20 Python
python求最大值最小值方法总结
Jun 25 Python
Python编程快速上手——Excel到CSV的转换程序案例分析
Feb 28 Python
Django serializer优化类视图的实现示例
Jul 16 Python
python 解决selenium 中的 .clear()方法失效问题
Sep 01 Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
Scrapy模拟登录赶集网的实现代码
Jul 07 #Python
You might like
PHP伪静态写法附代码
2008/06/20 PHP
PHP中防止SQL注入攻击和XSS攻击的两个简单方法
2010/04/15 PHP
PHP实现数组的笛卡尔积运算示例
2017/12/15 PHP
PHP设计模式之适配器模式原理与用法分析
2018/04/25 PHP
PHP实现转盘抽奖算法分享
2020/04/15 PHP
php常用字符串长度函数strlen()与mb_strlen()用法实例分析
2019/06/25 PHP
jQuery 研究心得 取得属性的值
2007/11/30 Javascript
jquery 得到当前页面高度和宽度的两个函数
2010/02/21 Javascript
不要在cookie中使用特殊字符的原因分析
2010/07/13 Javascript
Javascript中判断变量是数组还是对象(array还是object)
2013/08/14 Javascript
js面向对象编程之如何实现方法重载
2014/07/02 Javascript
js实现对table动态添加、删除和更新的方法
2015/02/10 Javascript
每天一篇javascript学习小结(RegExp对象)
2015/11/17 Javascript
JavaScript自学笔记(必看篇)
2016/06/23 Javascript
JavaScript 随机验证码的生成实例代码
2016/09/22 Javascript
KnockoutJS 3.X API 第四章之表单textInput、hasFocus、checked绑定
2016/10/11 Javascript
node中koa中间件机制详解
2017/08/22 Javascript
vue.js实例对象+组件树的详细介绍
2017/10/20 Javascript
Node.js从字符串生成文件流的实现方法
2019/08/18 Javascript
JS前端广告拦截实现原理解析
2020/02/17 Javascript
JavaScript 俄罗斯方块游戏实现方法与代码解释
2020/04/08 Javascript
[50:58]2018DOTA2亚洲邀请赛 4.1 小组赛 B组 Mineski vs EG
2018/04/03 DOTA
在类Unix系统上开始Python3编程入门
2015/08/20 Python
实例讲解Python中SocketServer模块处理网络请求的用法
2016/06/28 Python
python爬虫 模拟登录人人网过程解析
2019/07/31 Python
python写入数据到csv或xlsx文件的3种方法
2019/08/23 Python
英国女士和男士时尚服装网上购物:Top Labels Online
2018/03/25 全球购物
年度考核自我鉴定
2014/02/02 职场文书
升旗仪式主持词
2014/03/19 职场文书
运动会横幅标语
2014/06/17 职场文书
天猫活动策划方案
2014/08/21 职场文书
离婚案件答辩状
2015/05/22 职场文书
2015年秋季运动会加油稿
2015/07/22 职场文书
JavaScript组合继承详解
2021/11/07 Javascript
Python中的pprint模块
2021/11/27 Python
天谕手游15杯全调酒配方和调酒券的获得方式
2022/04/06 其他游戏