Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中for循环的使用方法
May 14 Python
Django的数据模型访问多对多键值的方法
Jul 21 Python
Windows系统下使用flup搭建Nginx和Python环境的方法
Dec 25 Python
python基础教程之匿名函数lambda
Jan 17 Python
itchat接口使用示例
Oct 23 Python
Python在for循环中更改list值的方法【推荐】
Aug 17 Python
Python单元测试unittest的具体使用示例
Dec 17 Python
与Django结合利用模型对上传图片预测的实例详解
Aug 07 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
Python全面分析系统的时域特性和频率域特性
Feb 26 Python
在python中修改.properties文件的操作
Apr 08 Python
python3爬虫中多线程的优势总结
Nov 24 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
php 按指定元素值去除数组元素的实现方法
2011/11/04 PHP
php 在windows下配置虚拟目录的方法介绍
2013/06/26 PHP
PHP用strstr()函数阻止垃圾评论(通过判断a标记)
2013/09/28 PHP
php mysqli查询语句返回值类型实例分析
2016/06/29 PHP
Zend Framework常用校验器详解
2016/12/09 PHP
php 常用的系统函数
2017/02/07 PHP
PHP+Session防止表单重复提交的解决方法
2018/04/09 PHP
jquery BS,dialog控件自适应大小
2009/07/06 Javascript
javascript 流畅动画实现原理
2009/09/08 Javascript
JSON 学习之JSON in JavaScript详细使用说明
2010/02/23 Javascript
js下拉菜单语言选项简单实现
2013/09/23 Javascript
jquery基础教程之deferred对象使用方法
2014/01/22 Javascript
javascript操作excel生成报表全攻略
2014/05/04 Javascript
JavaScript 获取任一float型小数点后两位的小数
2014/06/30 Javascript
nodejs URL模块操作URL相关方法介绍
2015/03/03 NodeJs
jQuery 实现评论等级好评差评特效
2016/05/06 Javascript
Bootstrap开发实战之第一次接触Bootstrap
2016/06/02 Javascript
小程序二次贝塞尔曲线实现购物车商品曲线飞入效果
2019/01/07 Javascript
利用Promise自定义一个GET请求的函数示例代码
2019/03/20 Javascript
详解用js代码触发dom事件的实现方案
2020/06/10 Javascript
Node.js fs模块原理及常见用途
2020/10/22 Javascript
[01:08:32]DOTA2-DPC中国联赛 正赛 DLG vs PHOENIX BO3 第二场 1月18日
2021/03/11 DOTA
pandas 将list切分后存入DataFrame中的实例
2018/07/03 Python
基于python实现聊天室程序
2018/07/27 Python
Python安装selenium包详细过程
2019/07/23 Python
python手机号前7位归属地爬虫代码实例
2020/03/31 Python
python3.5的包存放的具体路径
2020/08/16 Python
MoviePy常用剪辑类及Python视频剪辑自动化
2020/12/18 Python
YOOX美国官方网站:全球著名的多品牌时尚网络概念店
2016/09/11 全球购物
教师实习的自我鉴定
2013/10/26 职场文书
个人担保书格式范文
2014/05/12 职场文书
2015年监理工作总结范文
2015/04/07 职场文书
社区义诊通知
2015/04/24 职场文书
幼儿园班级工作总结2015
2015/05/25 职场文书
2016拓展训练心得体会范文
2016/01/12 职场文书
uwsgi+nginx代理Django无法访问静态资源的解决
2021/05/10 Servers