pytorch加载自己的图像数据集实例


Posted in Python onJuly 07, 2020

之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码。到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题。

参考文章https://3water.com/article/177613.htm

下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor。最后读取第一张图片并显示。

# 数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

transform = transforms.Compose([
 transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
 # transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])

#定义自己的数据集合
class FlameSet(data.Dataset):
 def __init__(self,root):
  # 所有图片的绝对路径
  imgs=os.listdir(root)
  self.imgs=[os.path.join(root,k) for k in imgs]
  self.transforms=transform

 def __getitem__(self, index):
  img_path = self.imgs[index]
  pil_img = Image.open(img_path)
  if self.transforms:
   data = self.transforms(pil_img)
  else:
   pil_img = np.asarray(pil_img)
   data = torch.from_numpy(pil_img)
  return data

 def __len__(self):
  return len(self.imgs)

if __name__ == '__main__':
 dataSet=FlameSet('./test')
 print(dataSet[0])

显示结果:

pytorch加载自己的图像数据集实例

补充知识:使用Pytorch进行读取本地的MINIST数据集并进行装载

pytorch中的torchvision.datasets中自带MINIST数据集,可直接调用模块进行获取,也可以进行自定义自己的Dataset类进行读取本地数据和初始化数据。

1. 直接使用pytorch自带的MNIST进行下载:

缺点: 下载速度较慢,而且如果中途下载失败一般得是重新进行执行代码进行下载:

# # 训练数据和测试数据的下载
# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
  root="./data", # 下载数据,并且存放在data文件夹中
  train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
  transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
  download=True # 瞎子啊过程中如果中断,或者下载完成之后再次运行,则会出现报错
)

testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

2. 自定义dataset类进行数据的读取以及初始化。

其中自己下载的MINIST数据集的内容如下:

pytorch加载自己的图像数据集实例

自己定义的dataset类需要继承: Dataset

需要实现必要的魔法方法:

__init__魔法方法里面进行读取数据文件

__getitem__魔法方法进行支持下标访问

__len__魔法方法返回自定义数据集的大小,方便后期遍历

示例如下:

class DealDataset(Dataset):
  """
    读取数据、初始化数据
  """
  def __init__(self, folder, data_name, label_name,transform=None):
    (train_set, train_labels) = load_minist_data.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
    self.train_set = train_set
    self.train_labels = train_labels
    self.transform = transform

  def __getitem__(self, index):

    img, target = self.train_set[index], int(self.train_labels[index])
    if self.transform is not None:
      img = self.transform(img)
    return img, target

  def __len__(self):
    return len(self.train_set)

其中load_minist_data.load_data也是我们自己写的读取数据文件的函数,即放在了load_minist_data.py中的load_data函数中。具体实现如下:

def load_data(data_folder, data_name, label_name):
 """
    data_folder: 文件目录
    data_name: 数据文件名
    label_name:标签数据文件名
  """
 with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
  y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

 with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
  x_train = np.frombuffer(
    imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
 return (x_train, y_train)

编写完自定义的dataset就可以进行实例化该类并装载数据:

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset('MNIST_data/', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
testDataset = DealDataset('MNIST_data/', "t10k-images-idx3-ubyte.gz","t10k-labels-idx1-ubyte.gz",transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = dataloader.DataLoader(
  dataset=trainDataset,
  batch_size=100, # 一个批次可以认为是一个包,每个包中含有100张图片
  shuffle=False,
)

test_loader = dataloader.DataLoader(
  dataset=testDataset,
  batch_size=100,
  shuffle=False,
)

构建简单的神经网络并进行训练和测试:

class NeuralNet(nn.Module):

  def __init__(self, input_num, hidden_num, output_num):
    super(NeuralNet, self).__init__()
    self.fc1 = nn.Linear(input_num, hidden_num)
    self.fc2 = nn.Linear(hidden_num, output_num)
    self.relu = nn.ReLU()

  def forward(self,x):
    x = self.fc1(x)
    x = self.relu(x)
    y = self.fc2(x)
    return y

# 参数初始化
epoches = 5
lr = 0.001
input_num = 784
hidden_num = 500
output_num = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 产生训练模型对象以及定义损失函数和优化函数
model = NeuralNet(input_num, hidden_num, output_num)
model.to(device)
criterion = nn.CrossEntropyLoss() # 使用交叉熵作为损失函数
optimizer = optim.Adam(model.parameters(), lr=lr)

# 开始循环训练
for epoch in range(epoches): # 一个epoch可以认为是一次训练循环
  for i, data in enumerate(train_loader):
    (images, labels) = data
    images = images.reshape(-1, 28*28).to(device)
    labels = labels.to(device)
    output = model(images) # 经过模型对象就产生了输出
    loss = criterion(output, labels.long()) # 传入的参数: 输出值(预测值), 实际值(标签)
    optimizer.zero_grad() # 梯度清零
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0: # i表示样本的编号
      print('Epoch [{}/{}], Loss: {:.4f}'
         .format(epoch + 1, epoches, loss.item())) # {}里面是后面需要传入的变量
                              # loss.item
# 开始测试
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, 28*28).to(device) # 此处的-1一般是指自动匹配的意思, 即不知道有多少行,但是确定了列数为28 * 28
                           # 其实由于此处28 * 28本身就已经等于了原tensor的大小,所以,行数也就确定了,为1
    labels = labels.to(device)
    output = model(images)
    _, predicted = torch.max(output, 1)
    total += labels.size(0) # 此处的size()类似numpy的shape: np.shape(train_images)[0]
    correct += (predicted == labels).sum().item()
  print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

以上这篇pytorch加载自己的图像数据集实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中除法使用的注意事项
Aug 21 Python
详解Python中open()函数指定文件打开方式的用法
Jun 04 Python
python 简单的多线程链接实现代码
Aug 28 Python
Flask解决跨域的问题示例代码
Feb 12 Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 Python
Python wxpython模块响应鼠标拖动事件操作示例
Aug 23 Python
使用python批量化音乐文件格式转换的实例
Jan 09 Python
Python3+OpenCV2实现图像的几何变换(平移、镜像、缩放、旋转、仿射)
May 13 Python
Python 使用PyQt5 完成选择文件或目录的对话框方法
Jun 27 Python
PyCharm2018 安装及破解方法实现步骤
Sep 09 Python
Python语法之精妙的十个知识点(装B语法)
Jan 18 Python
Python访问Redis的详细操作
Jun 26 Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
Scrapy模拟登录赶集网的实现代码
Jul 07 #Python
You might like
php设计模式 Facade(外观模式)
2011/06/26 PHP
PHP采集类snoopy详细介绍(snoopy使用教程)
2014/06/19 PHP
PHP临时文件的安全性分析
2014/07/04 PHP
phplist及phpmailer(组合使用)通过gmail发送邮件的配置方法
2016/03/30 PHP
支持汉转拼和拼音分词的PHP中文工具类ChineseUtil
2018/02/23 PHP
windows8.1+iis8.5下安装node.js开发环境
2014/12/12 Javascript
javascript 获取浏览器版本
2015/01/21 Javascript
javascript中in运算符用法分析
2015/04/28 Javascript
JS+Canvas 实现下雨下雪效果
2016/05/18 Javascript
jQuery插件pagination实现无刷新分页
2016/05/21 Javascript
浅谈Vuejs Prop基本用法
2017/08/17 Javascript
VUE饿了么树形控件添加增删改功能的示例代码
2017/10/17 Javascript
JavaScript实现的拼图算法分析
2019/02/13 Javascript
解决Vue+Electron下Vuex的Dispatch没有效果问题
2019/05/20 Javascript
vue-cli随机生成port源码的方法
2019/09/02 Javascript
使用preload预加载页面资源时注意事项
2020/02/03 Javascript
Vue页面跳转传递参数及接收方式
2020/09/09 Javascript
antd form表单数据回显操作
2020/11/02 Javascript
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
2014/05/05 Python
python实现在无须过多援引的情况下创建字典的方法
2014/09/25 Python
Python实现单词拼写检查
2015/04/25 Python
django轻松使用富文本编辑器CKEditor的方法
2017/03/30 Python
Python基于numpy灵活定义神经网络结构的方法
2017/08/19 Python
python pandas dataframe 行列选择,切片操作方法
2018/04/10 Python
对Python 检查文件名是否规范的实例详解
2019/06/10 Python
解决import tensorflow导致jupyter内核死亡的问题
2021/02/06 Python
Maje德国官网:法国女性成衣品牌
2017/02/10 全球购物
Steiff台湾官网:德国金耳釦泰迪熊
2019/12/26 全球购物
如果重写了对象的equals()方法,需要考虑什么
2014/11/02 面试题
药物学专业学生的自我评价
2013/10/27 职场文书
红旗团支部事迹材料
2014/01/27 职场文书
员工拓展培训方案
2014/02/15 职场文书
质量安全标语
2014/06/07 职场文书
群众路线自我剖析材料
2014/10/08 职场文书
2016优秀护士先进个人事迹材料
2016/02/25 职场文书
HR必备:超全面的薪酬待遇管理方案!
2019/07/12 职场文书