Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python标准日志模块logging的使用方法
Nov 01 Python
Python中使用logging模块打印log日志详解
Apr 05 Python
Python将图片批量从png格式转换至WebP格式
Aug 22 Python
详解python基础之while循环及if判断
Aug 24 Python
python爬虫之BeautifulSoup 使用select方法详解
Oct 23 Python
Linux(Redhat)安装python3.6虚拟环境(推荐)
May 05 Python
Python操作word常见方法示例【win32com与docx模块】
Jul 17 Python
用python做游戏的细节详解
Jun 25 Python
解决Pycharm 包已经下载,但是运行代码提示找不到模块的问题
Aug 31 Python
解决python replace函数替换无效问题
Jan 18 Python
关于python pygame游戏进行声音添加的技巧
Oct 24 Python
Python字符串的转义字符
Apr 07 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
php面向对象全攻略 (一) 面向对象基础知识
2009/09/30 PHP
第五章 php数组操作
2011/12/30 PHP
PHP获取服务器端信息的方法
2014/11/28 PHP
php修改上传图片尺寸的方法
2015/04/14 PHP
php基于Fleaphp框架实现cvs数据导入MySQL的方法
2016/02/23 PHP
php版微信数据统计接口用法示例
2016/10/12 PHP
不用构造函数(Constructor)new关键字也能实现JavaScript的面向对象
2013/01/11 Javascript
jQuery trigger()方法用法介绍
2015/01/13 Javascript
Javascript中的arguments与重载介绍
2015/03/15 Javascript
nodejs创建web服务器之hello world程序
2015/08/20 NodeJs
浅谈jquery的map()和each()方法
2016/06/12 Javascript
JS实现的图片预览插件与用法示例【不上传图片】
2016/11/25 Javascript
微信小程序 PHP后端form表单提交实例详解
2017/01/12 Javascript
node.js实现微信JS-API封装接口的示例代码
2017/09/06 Javascript
Vue引用第三方datepicker插件无法监听datepicker输入框的值的解决
2018/01/27 Javascript
vue 表单输入格式化中文输入法异常问题
2018/05/30 Javascript
Vue模拟数据,实现路由进入商品详情页面的示例
2018/08/31 Javascript
Vue面试题及Vue知识点整理
2018/10/07 Javascript
如何利用ES6进行Promise封装总结
2019/02/11 Javascript
Vue Autocomplete 自动完成功能简单示例
2019/05/25 Javascript
Vue记住滚动条和实现下拉加载的完美方法
2020/07/31 Javascript
[02:44]DOTA2英雄基础教程 钢背兽
2013/12/19 DOTA
Python Cookie 读取和保存方法
2018/12/28 Python
Django urls.py重构及参数传递详解
2019/07/23 Python
Django 开发调试工具 Django-debug-toolbar使用详解
2019/07/23 Python
python多线程和多进程关系详解
2020/12/14 Python
工程部经理岗位职责
2013/12/08 职场文书
毕业生实习鉴定
2013/12/11 职场文书
网络研修随笔感言
2014/02/17 职场文书
物流管理专业毕业生自荐信
2014/03/04 职场文书
运动会广播稿100字
2014/09/14 职场文书
机关单位工作失职检讨书
2014/11/20 职场文书
学生自我评语
2015/01/04 职场文书
放牛班的春天观后感
2015/06/01 职场文书
一条 SQL 语句执行过程
2022/03/17 MySQL
详解Golang如何实现支持随机删除元素的堆
2022/09/23 Python