Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下函数参数的传递(参数带星号的说明)
Sep 19 Python
Python程序设计入门(5)类的使用简介
Jun 16 Python
python判断图片宽度和高度后删除图片的方法
May 22 Python
Python2.7简单连接与操作MySQL的方法
Apr 27 Python
Python模拟三级菜单效果
Sep 11 Python
Python实现针对给定单链表删除指定节点的方法
Apr 12 Python
matplotlib给子图添加图例的方法
Aug 03 Python
Python实现的微信支付方式总结【三种方式】
Apr 13 Python
如何使用pyinstaller打包32位的exe程序
May 26 Python
python tkinter实现屏保程序
Jul 30 Python
基于keras输出中间层结果的2种实现方式
Jan 24 Python
python定义具名元组实例操作
Feb 28 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
PHP通过COM使用ADODB的简单例子
2006/12/31 PHP
php循环输出数据库内容的代码
2008/05/24 PHP
Mysql的Root密码忘记,查看或修改的解决方法(图文介绍)
2013/06/14 PHP
PHP对象相互引用的内存溢出实例分析
2014/08/28 PHP
typecho插件编写教程(二):写一个新插件
2015/05/28 PHP
PHP生成可点击刷新的验证码简单示例
2016/05/13 PHP
浅谈PHP中try{}catch{}的使用方法
2016/12/09 PHP
php利用ob_start()清除输出和选择性输出的方法
2018/01/18 PHP
ThinkPHP5.0多个文件上传后找不到临时文件的修改方法
2018/07/30 PHP
可拖动窗口,附带鼠标控制渐变透明,开启关闭功能
2006/06/26 Javascript
javascript 跳转代码集合
2009/12/03 Javascript
javascript放大镜效果的简单实现
2013/12/09 Javascript
node.js中的fs.readdir方法使用说明
2014/12/17 Javascript
SWFObject基本用法实例分析
2015/07/20 Javascript
JavaScript取得键盘按下方向键是哪个的方法
2015/08/04 Javascript
JS代码实现table数据分页效果
2016/05/26 Javascript
AngularJs定时器$interval 和 $timeout详解
2017/05/25 Javascript
基于JS实现移动端左滑删除功能
2017/07/28 Javascript
JS传播事件、取消事件默认行为、阻止事件传播详解
2017/08/14 Javascript
解决Linux无法正常安装与卸载Node.js的方法
2018/01/19 Javascript
使用vue + less 实现简单换肤功能的示例
2018/02/21 Javascript
jQuery实现鼠标滑动切换图片
2020/05/27 jQuery
[01:24:16]2018DOTA2亚洲邀请赛 4.6 全明星赛
2018/04/10 DOTA
在Python中操作列表之list.extend()方法的使用
2015/05/20 Python
Python调用ctypes使用C函数printf的方法
2017/08/23 Python
Python如何使用paramiko模块连接linux
2020/03/18 Python
利用HTML5+css3+jquery+weui实现仿微信聊天界面功能
2018/01/08 HTML / CSS
EMPHASIS艾斐诗官网:周生生旗下原创精品珠宝品牌
2020/12/17 全球购物
金蝶的一道SQL笔试题
2012/12/18 面试题
面试求职的个人自我评价
2013/11/16 职场文书
酒吧总经理岗位职责
2013/12/10 职场文书
优秀党员主要事迹
2014/01/19 职场文书
初中优秀学生评语
2014/12/29 职场文书
幼儿园母亲节活动总结
2015/02/10 职场文书
2016年安全月活动总结
2016/04/06 职场文书
python+opencv实现目标跟踪过程
2022/06/21 Python