Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 正则表达式(转义问题)
Dec 15 Python
Python实现去除代码前行号的方法
Mar 10 Python
Python如何读取MySQL数据库表数据
Mar 11 Python
Python实现压缩和解压缩ZIP文件的方法分析
Sep 28 Python
python读取各种文件数据方法解析
Dec 29 Python
Python实现截取PDF文件中的几页代码实例
Mar 11 Python
详解Python的循环结构知识点
May 20 Python
浅谈Django中的QueryDict元素为数组的坑
Mar 31 Python
python操作微信自动发消息的实现(微信聊天机器人)
Jul 14 Python
10款最佳Python开发工具推荐,每一款都是神器
Oct 15 Python
python 机器学习的标准化、归一化、正则化、离散化和白化
Apr 16 Python
python 多态 协议 鸭子类型详解
Nov 27 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
海河写的 Discuz论坛帖子调用js的php代码
2007/08/23 PHP
简单的PHP缓存设计实现代码
2011/09/30 PHP
PHP导出EXCEL快速开发指南--PHPEXCEL的使用详解
2013/06/03 PHP
让innerHTML的脚本也可以运行起来
2006/07/01 Javascript
JS IE和FF兼容性问题汇总
2009/02/09 Javascript
Jquery ThickBox插件使用心得(不建议使用)
2010/09/08 Javascript
JS实现将人民币金额转换为大写的示例代码
2014/02/13 Javascript
完美兼容各大浏览器的jQuery仿新浪图文淡入淡出间歇滚动特效
2014/11/12 Javascript
JS去除iframe滚动条的方法
2015/04/01 Javascript
JS实现点击按钮后框架内载入不同网页的方法
2015/05/05 Javascript
基于javascript实现页面加载loading效果
2020/09/15 Javascript
Bootstrap媒体对象的实现
2016/05/01 Javascript
JavaScript自学笔记(必看篇)
2016/06/23 Javascript
几种响应式文字详解
2017/05/19 Javascript
JavaScript创建对象的七种方式(推荐)
2017/06/26 Javascript
史上最全JavaScript常用的简写技巧(推荐)
2017/08/17 Javascript
Web开发使用Angular实现用户密码强度判别的方法
2017/09/27 Javascript
JavaScript实现一个简易的计算器实例代码
2018/05/10 Javascript
详解Ant Design of React的安装和使用方法
2018/12/27 Javascript
vue用BMap百度地图实现即时搜索功能
2019/09/26 Javascript
基于JavaScript实现单例模式
2019/10/30 Javascript
[08:42]DOTA2每周TOP10 精彩击杀集锦vol.2
2014/06/25 DOTA
[07:26]2015国际邀请赛第二日TOP10集锦
2015/08/06 DOTA
[55:25]VGJ.T vs Optic Supermajor小组赛D组 BO3 第三场 6.3
2018/06/04 DOTA
python生成特定分布数的实例
2019/12/05 Python
基于Python共轭梯度法与最速下降法之间的对比
2020/04/02 Python
Pyside2中嵌入Matplotlib的绘图的实现
2021/02/22 Python
使用Python webdriver图书馆抢座自动预约的正确方法
2021/03/04 Python
Web页面中八种创建多列等高(等高列布局)的实现技术
2012/12/24 HTML / CSS
Reebonz中国官网:新加坡奢侈品购物网站
2017/03/17 全球购物
名人珠宝设计师:Melinda Maria Jewelry
2019/03/06 全球购物
俄罗斯大型在线书店:Читай-город
2019/10/10 全球购物
学习经验演讲稿
2014/05/10 职场文书
学雷锋标兵事迹材料
2014/08/18 职场文书
优秀高中学生评语
2014/12/30 职场文书
2015年政务公开工作总结
2015/05/19 职场文书