Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pygame学习笔记(6):完成一个简单的游戏
Apr 15 Python
Python selenium如何设置等待时间
Sep 15 Python
python3基于OpenCV实现证件照背景替换
Jul 18 Python
Python过滤txt文件内重复内容的方法
Oct 21 Python
利用Python如何实现一个小说网站雏形
Nov 23 Python
python 获取微信好友列表的方法(微信web)
Feb 21 Python
关于Python作用域自学总结
Jun 10 Python
pytorch 实现tensor与numpy数组转换
Dec 27 Python
Python中url标签使用知识点总结
Jan 16 Python
如何Tkinter模块编写Python图形界面
Oct 14 Python
基于Python爬取素材网站音频文件
Oct 21 Python
使用Python下载抖音各大V视频的思路详解
Feb 06 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
全国FM电台频率大全 - 28 甘肃省
2020/03/11 无线电
php表单提交与$_POST实例分析
2015/01/26 PHP
PHP常见的几种攻击方式实例小结
2019/04/29 PHP
PHP检查文件是否存在,不存在自动创建及读取文件内容操作示例
2020/01/23 PHP
PHP标准库 (SPL)――Countable用法示例
2020/06/05 PHP
基于jquery的无限级联下拉框js插件
2011/10/29 Javascript
jsvascript图像处理—(计算机视觉应用)图像金字塔
2013/01/15 Javascript
js 获取(接收)地址栏参数值的方法
2013/04/01 Javascript
14个有用的Jquery技巧分享
2015/01/08 Javascript
JavaScript计算器网页版实现代码分享
2016/07/15 Javascript
Node.js配合node-http-proxy解决本地开发ajax跨域问题
2016/08/31 Javascript
微信小程序--onShareAppMessage分享参数用处(页面分享)
2017/04/18 Javascript
JavaScript中常见的八个陷阱总结
2017/06/28 Javascript
浅析vue中常见循环遍历指令的使用 v-for
2018/04/18 Javascript
vue使用canvas实现移动端手写签名
2020/09/22 Javascript
vantUI 获得piker选中值的自定义ID操作
2020/11/04 Javascript
python将人民币转换大写的脚本代码
2013/02/10 Python
python中__call__内置函数用法实例
2015/06/04 Python
python实现基本进制转换的方法
2015/07/11 Python
Python实现自定义函数的5种常见形式分析
2018/06/16 Python
Python爬虫小技巧之伪造随机的User-Agent
2018/09/13 Python
python实现旋转和水平翻转的方法
2018/10/25 Python
django多文件上传,form提交,多对多外键保存的实例
2019/08/06 Python
Free People中国官网:波西米亚风格女装服饰
2016/08/30 全球购物
美国旅游网站:Tours4Fun
2017/02/17 全球购物
MSC邮轮官方网站:加勒比海、地中海和世界各地的假期
2018/08/27 全球购物
美国香薰蜡烛品牌:PADDYWAX
2018/10/06 全球购物
高中生期末评语
2014/01/28 职场文书
小学校园活动策划
2014/01/30 职场文书
你的创业计划书怎样才能打动风投
2014/02/06 职场文书
大学生活动策划方案
2014/02/10 职场文书
趣味比赛活动方案
2014/02/15 职场文书
班级旅游计划书
2014/05/03 职场文书
三八妇女节演讲稿
2014/05/27 职场文书
小学教师师德师风个人整改措施
2014/09/18 职场文书
详解Python中下划线的5种含义
2021/07/15 Python