Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作


Posted in Python onMarch 03, 2021

【源码GitHub地址】:点击进入

1. 问题描述

之前写了一篇关于《pytorch Dataset, DataLoader产生自定义的训练数据》的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的。

比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的;

但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后再迭代返回,就会出现类似如下的错误:

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp> return [default_collate(samples) for samples in transposed]

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate

raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>

2. 一般的解决方法

一般的解决方法也很简单粗暴,就是在传递数据给Dataset前,就做数据清理,把不存在的图片,损坏的数据都提前清理掉。

是的,这个是最简单粗暴的。

3. 另一种解决方法:自定义返回数据的规则:collate_fn()校对函数

我们希望不管传递什么处理给Dataset,Dataset都进行处理,如果不存在或者异常,就返回None,而在DataLoader时,对于不存为None的数据,都去除掉。

这样就保证在迭代过程中,DataLoader获得batch数据都是正确的。

比如读取batch_size=5的图片数据,如果其中有1个(或者多个)图片是不存在,那么返回的batch应该把不存在的数据过滤掉,即返回5-1=4大小的batch的数据。

是的,我要实现的就是这个功能:返回的batch数据会自定清理掉不合法的数据。

3.1 Pytorch数据处理函数:Dataset和 DataLoader

Pytorch有两个数据处理函数:Dataset和 DataLoader

from torch.utils.data import Dataset, DataLoader

其中Dataset用于定义数据的读取和预处理操作,而DataLoader用于加载并产生批训练数据。

torch.utils.data.DataLoader参数说明:

DataLoader(object)可用参数:

1、dataset(Dataset) 传入的数据集

2、batch_size(int, optional) 每个batch有多少个样本

3、shuffle(bool, optional) 在每个epoch开始的时候,对数据进行重新排序

4、sampler(Sampler, optional) 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

5、batch_sampler(Sampler, optional) 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

6、num_workers (int, optional) 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

7、collate_fn (callable, optional) 将一个list的sample组成一个mini-batch的函数

8、pin_memory (bool, optional) 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

9、drop_last (bool, optional) 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

10、timeout(numeric, optional) 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

11、worker_init_fn (callable, optional) 每个worker初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

我们要用到的是collate_fn()回调函数

3.2 自定义collate_fn()函数:

torch.utils.data.DataLoader的collate_fn()用于设置batch数据拼接方式,默认是default_collate函数,但当batch中含有None等数据时,默认的default_collate校队方法会出现错误。因此,我们需要自定义collate_fn()函数:

方法也很简单:只需在原来的default_collate函数中添加下面几句代码:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了。

# 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)

dataset_collate.py:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset_collate.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-06-07 17:09:13
"""
 
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import re
from torch._six import container_abcs, string_classes, int_classes 
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""
 
np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
 
numpy_type_map = {
 'float64': torch.DoubleTensor,
 'float32': torch.FloatTensor,
 'float16': torch.HalfTensor,
 'int64': torch.LongTensor,
 'int32': torch.IntTensor,
 'int16': torch.ShortTensor,
 'int8': torch.CharTensor,
 'uint8': torch.ByteTensor,
}
 
def collate_fn(batch):
 '''
 collate_fn (callable, optional): merges a list of samples to form a mini-batch.
 该函数参考touch的default_collate函数,也是DataLoader的默认的校对方法,当batch中含有None等数据时,
 默认的default_collate校队方法会出现错误
 一种的解决方法是:
 判断batch中image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 :param batch:
 :return:
 '''
 r"""Puts each data field into a tensor with outer dimension batch size"""
 # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)
 
 elem_type = type(batch[0])
 if isinstance(batch[0], torch.Tensor):
 out = None
 if _use_shared_memory:
  # If we're in a background process, concatenate directly into a
  # shared memory tensor to avoid an extra copy
  numel = sum([x.numel() for x in batch])
  storage = batch[0].storage()._new_shared(numel)
  out = batch[0].new(storage)
 return torch.stack(batch, 0, out=out)
 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
  and elem_type.__name__ != 'string_':
 elem = batch[0]
 if elem_type.__name__ == 'ndarray':
  # array of string classes and object
  if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  raise TypeError(error_msg_fmt.format(elem.dtype))
 
  return collate_fn([torch.from_numpy(b) for b in batch])
 if elem.shape == (): # scalars
  py_type = float if elem.dtype.name.startswith('float') else int
  return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
 elif isinstance(batch[0], float):
 return torch.tensor(batch, dtype=torch.float64)
 elif isinstance(batch[0], int_classes):
 return torch.tensor(batch)
 elif isinstance(batch[0], string_classes):
 return batch
 elif isinstance(batch[0], container_abcs.Mapping):
 return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
 elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
 return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
 elif isinstance(batch[0], container_abcs.Sequence):
 transposed = zip(*batch)#ok
 return [collate_fn(samples) for samples in transposed]
 
 raise TypeError((error_msg_fmt.format(type(batch[0]))))

测试方法:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-03-07 18:45:06
"""
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from utils import dataset_collate
import os
import cv2
from PIL import Image
def read_image(path,mode='RGB'):
 '''
 :param path:
 :param mode: RGB or L
 :return:
 '''
 return Image.open(path).convert(mode)
 
class TorchDataset(Dataset):
 def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
 '''
 :param filename: 数据文件TXT:格式:imge_name.jpg label1_id labe2_id
 :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
 :param resize_height 为None时,不进行缩放
 :param resize_width 为None时,不进行缩放,
    PS:当参数resize_height或resize_width其中一个为None时,可实现等比例缩放
 :param repeat: 所有样本数据重复次数,默认循环一次,当repeat为None时,表示无限循环<sys.maxsize
 :param transform:预处理
 '''
 self.image_dir = image_dir
 self.image_id_list=image_id_list
 self.len = len(image_id_list)
 self.repeat = repeat
 self.resize_height = resize_height
 self.resize_width = resize_width
 self.transform= transform
 
 def __getitem__(self, i):
 index = i % self.len
 # print("i={},index={}".format(i, index))
 image_id = self.image_id_list[index]
 image_path = os.path.join(self.image_dir, image_id)
 img = self.load_data(image_path)
 
 if img is None:
  return None,image_id
 img = self.data_preproccess(img)
 return img,image_id
 
 def __len__(self):
 if self.repeat == None:
  data_len = 10000000
 else:
  data_len = len(self.image_id_list) * self.repeat
 return data_len
 
 def load_data(self, path):
 '''
 加载数据
 :param path:
 :param resize_height:
 :param resize_width:
 :param normalization: 是否归一化
 :return:
 '''
 try:
  image = read_image(path)
 except Exception as e:
  image=None
  print(e)
 # image = image_processing.read_image(path)#用opencv读取图像
 return image
 
 def data_preproccess(self, data):
 '''
 数据预处理
 :param data:
 :return:
 '''
 if self.transform is not None:
  data = self.transform(data)
 return data
 
if __name__=='__main__':
 
 resize_height = 224
 resize_width = 224
 image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"]
 image_dir="../dataset/test_images/images"
 # 相关预处理的初始化
 '''class torchvision.transforms.ToTensor把shape=(H,W,C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray数据
 # 转换成shape=(C,H,W)的像素数据,并且被归一化到[0.0, 1.0]的torch.FloatTensor类型。
 '''
 train_transform = transforms.Compose([
 transforms.Resize(size=(resize_height, resize_width)),
 # transforms.RandomHorizontalFlip(),#随机翻转图像
 transforms.RandomCrop(size=(resize_height, resize_width), padding=4), # 随机裁剪
 transforms.ToTensor(), # 吧shape=(H,W,C)->换成shape=(C,H,W),并且归一化到[0.0, 1.0]的torch.FloatTensor类型
 # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#给定均值(R,G,B) 方差(R,G,B),将会把Tensor正则化
 ])
 
 epoch_num=2 #总样本循环次数
 batch_size=5 #训练时的一组数据的大小
 train_data_nums=10
 max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数
 
 train_data = TorchDataset(image_id_list=image_id_list,
    image_dir=image_dir,
    resize_height=resize_height,
    resize_width=resize_width,
    repeat=1,
    transform=train_transform)
 # 使用默认的default_collate会报错
 # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
 # 使用自定义的collate_fn
 train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=dataset_collate.collate_fn)
 
 
 # [1]使用epoch方法迭代,TorchDataset的参数repeat=1
 for epoch in range(epoch_num):
 for step,(batch_image, batch_label) in enumerate(train_loader):
  if batch_image is None and batch_label is None:
  print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
  continue
  image=batch_image[0,:]
  image=image.numpy()#image=np.array(image)
  image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c]
  cv2.imshow("image",image)
  cv2.waitKey(2000)
  print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label))
  # batch_x, batch_y = Variable(batch_x), Variable(batch_y)

输出结果说明:

batch_size=5,输入图片列表image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"] ,其中"ddd.jpg","111.jpg"是不存在的,resize_width=224,正常情况下返回的数据应该是torch.Size([5, 3, 224, 224]),但由于"ddd.jpg","111.jpg"不存在,被过滤掉了,所以第一个batch的维度变为torch.Size([3, 3, 224, 224])

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
深入讲解Python中面向对象编程的相关知识
May 25 Python
栈和队列数据结构的基本概念及其相关的Python实现
Aug 24 Python
Python读取txt文件数据的方法(用于接口自动化参数化数据)
Jun 27 Python
Python装饰器模式定义与用法分析
Aug 06 Python
Python OpenCV读取png图像转成jpg图像存储的方法
Oct 28 Python
python 找出list中最大或者最小几个数的索引方法
Oct 30 Python
python中使用 xlwt 操作excel的常见方法与问题
Jan 13 Python
python钉钉机器人运维脚本监控实例
Feb 20 Python
python操作gitlab API过程解析
Dec 27 Python
sklearn+python:线性回归案例
Feb 24 Python
使用Python三角函数公式计算三角形的夹角案例
Apr 15 Python
python中使用np.delete()的实例方法
Feb 01 Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
解决pytorch 数据类型报错的问题
Mar 03 #Python
python反编译教程之2048小游戏实例
Mar 03 #Python
python 如何读、写、解析CSV文件
Mar 03 #Python
聊聊python在linux下与windows下导入模块的区别说明
Mar 03 #Python
python 递归相关知识总结
Mar 03 #Python
You might like
php Rename 更改文件、文件夹名称
2011/05/24 PHP
一个php短网址的生成代码(仿微博短网址)
2014/05/07 PHP
PHP常用的排序和查找算法
2015/08/06 PHP
php上传大文件设置方法
2016/04/14 PHP
根据地区不同显示时间的javascript代码
2007/08/13 Javascript
JavaScript 无符号右移运算符
2009/04/17 Javascript
十分钟打造AutoComplete自动完成效果代码
2009/12/26 Javascript
LazyLoad 延迟加载(按需加载)
2010/05/31 Javascript
js 为label标签和div标签赋值的方法
2013/08/08 Javascript
手机端网页点击链接触发自动拨打或保存电话的示例代码
2014/08/15 Javascript
AngularJS 实现按需异步加载实例代码
2015/10/18 Javascript
探究JavaScript函数式编程的乐趣
2015/12/14 Javascript
js如何判断输入字符串长度
2015/12/16 Javascript
JS实现一次性弹窗的方法【刷新后不弹出】
2016/12/26 Javascript
基于Vue实现微信小程序的图文编辑器
2018/07/25 Javascript
详解Nodejs get获取远程服务器接口数据
2019/03/26 NodeJs
jQuery层叠选择器用法实例分析
2019/06/28 jQuery
Nodejs 识别图片类型的方法
2019/08/15 NodeJs
如何实现一个简易版的vuex持久化工具
2019/09/11 Javascript
vue滚动插件better-scroll使用详解
2019/10/18 Javascript
Vue props中Object和Array设置默认值操作
2020/07/30 Javascript
[02:36]DOTA2英雄基础教程 帕格纳
2014/01/20 DOTA
Python利用flask sqlalchemy实现分页效果
2020/08/02 Python
使用Python快速搭建HTTP服务和文件共享服务的实例讲解
2018/06/04 Python
python实现视频分帧效果
2019/05/31 Python
详解基于python-django框架的支付宝支付案例
2019/09/23 Python
Python模块相关知识点小结
2020/03/09 Python
python爬虫用request库处理cookie的实例讲解
2021/02/20 Python
主持人演讲稿范文
2013/12/28 职场文书
家长会演讲稿范文
2014/01/10 职场文书
小学六年级学生评语
2014/04/22 职场文书
授权收款委托书范本
2014/10/10 职场文书
护士求职简历自我评价
2015/03/10 职场文书
辩论赛新闻稿
2015/07/17 职场文书
党员反邪教心得体会
2016/01/15 职场文书
求职信如何撰写?
2019/05/22 职场文书