Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用PyFetion来发送短信的例子
Apr 22 Python
python实现端口转发器的方法
Mar 13 Python
Python实例一个类背后发生了什么
Feb 09 Python
对pandas的算术运算和数据对齐实例详解
Dec 22 Python
pycharm打开命令行或Terminal的方法
Jan 16 Python
Python人工智能之路 之PyAudio 实现录音 自动化交互实现问答
Aug 13 Python
如何用OpenCV -python3实现视频物体追踪
Dec 04 Python
pytorch制作自己的LMDB数据操作示例
Dec 18 Python
pytorch 获取tensor维度信息示例
Jan 03 Python
Pycharm修改python路径过程图解
May 22 Python
音频处理 windows10下python三方库librosa安装教程
Jun 20 Python
python删除文件、清空目录的实现方法
Sep 23 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
windows下zendframework项目环境搭建(通过命令行配置)
2012/12/06 PHP
php实现加减法验证码代码
2014/02/14 PHP
php比较两个字符串长度的方法
2015/07/13 PHP
解决更换PHP5.4以上版本后Dedecms后台登录空白问题的方法
2015/10/23 PHP
PHP实现二维数组去重功能示例
2017/01/12 PHP
javascript中的几个运算符
2007/06/29 Javascript
js抽奖实现随机抽奖代码效果
2013/12/02 Javascript
jquery easyui 结合jsp简单展现table数据示例
2014/04/18 Javascript
jQuery验证插件 Validate详解
2014/11/20 Javascript
js简单实现Select互换数据的方法
2015/08/17 Javascript
基于jquery实现智能提示控件intellSeach.js
2016/03/17 Javascript
AngularJS模块详解及示例代码
2016/08/17 Javascript
JS正则截取两个字符串之间及字符串前后内容的方法
2017/01/06 Javascript
JavaScript实现打地鼠小游戏
2020/04/23 Javascript
Node.js Koa2使用JWT进行鉴权的方法示例
2018/08/17 Javascript
详解vue的双向绑定原理及实现
2019/05/05 Javascript
微信小程序和H5页面间相互跳转代码实例
2019/09/19 Javascript
vue.js实现只能输入数字的输入框
2019/10/19 Javascript
python 中文乱码问题深入分析
2011/03/13 Python
python对url格式解析的方法
2015/05/13 Python
django 2.0更新的10条注意事项总结
2018/01/05 Python
对python多线程中互斥锁Threading.Lock的简单应用详解
2019/01/11 Python
pygame实现贪吃蛇游戏(上)
2019/10/29 Python
flask开启多线程的具体方法
2020/08/02 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
2020/08/03 Python
Python importlib模块重载使用方法详解
2020/10/13 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
2021/01/26 Python
运动会广播稿200字
2014/01/15 职场文书
幼儿园教师国培感言
2014/02/02 职场文书
禁止高声喧哗的标语
2014/06/11 职场文书
教师求职自荐信范文
2015/03/04 职场文书
行政前台岗位职责
2015/04/16 职场文书
个人的事迹材料怎么写
2019/04/24 职场文书
教你解决往mysql数据库中存入汉字报错的方法
2021/05/06 MySQL
如何利用Matlab制作一款真正的拼图小游戏
2021/05/11 Python
python中的sys模块和os模块
2022/03/20 Python