PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python多线程编程(三):threading.Thread类的重要函数和方法
Apr 05 Python
Python while 循环使用的简单实例
Jun 08 Python
python 排序算法总结及实例详解
Sep 28 Python
Python实现判断给定列表是否有重复元素的方法
Apr 11 Python
ipython和python区别详解
Jun 26 Python
python使用writerows写csv文件产生多余空行的处理方法
Aug 01 Python
Django中自定义模型管理器(Manager)及方法
Sep 23 Python
django在保存图像的同时压缩图像示例代码详解
Feb 11 Python
python str字符串转uuid实例
Mar 03 Python
Python命名空间namespace及作用域原理解析
Jun 05 Python
Django如何实现密码错误报错提醒
Sep 04 Python
Elasticsearch 数据类型及管理
Apr 19 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
thinkphp3.0 模板中函数的使用
2012/11/13 PHP
PHP数据库链接类(PDO+Access)实例分享
2013/12/05 PHP
WordPress中"无法将上传的文件移动至"错误的解决方法
2015/07/01 PHP
PHP和C#可共用的可逆加密算法详解
2015/10/26 PHP
PHP上传文件参考配置大文件上传
2015/12/16 PHP
Thinkphp通过一个入口文件如何区分移动端和PC端
2017/04/18 PHP
实例讲解php实现多线程
2019/01/27 PHP
php反序列化长度变化尾部字符串逃逸(0CTF-2016-piapiapia)
2020/02/15 PHP
JQery 渐变图片导航效果代码 漂亮
2010/01/01 Javascript
JS 面向对象之神奇的prototype
2011/02/26 Javascript
js Dialog 实践分享
2012/10/22 Javascript
JavaScript的事件绑定(方便不支持js的时候)
2013/10/01 Javascript
JqueryMobile动态生成listView并实现刷新的两种方法
2014/03/05 Javascript
jQuery判断当前点击的是第几个li的代码
2014/09/26 Javascript
node.js中的http.response.addTrailers方法使用说明
2014/12/14 Javascript
浅谈JavaScript中的属性:如何遍历属性
2017/09/14 Javascript
React Native 真机断点调试+跨域资源加载出错问题的解决方法
2018/01/18 Javascript
jQuery实现form表单序列化转换为json对象功能示例
2018/05/23 jQuery
vue.js父子组件通信动态绑定的实例
2018/09/28 Javascript
JS实现二维数组元素的排列组合运算简单示例
2019/01/28 Javascript
解决Vue调用springboot接口403跨域问题
2019/09/02 Javascript
解决vue项目获取dom元素宽高总是不准确问题
2020/07/29 Javascript
Python利用Beautiful Soup模块创建对象详解
2017/03/27 Python
Python语言实现百度语音识别API的使用实例
2017/12/13 Python
Django使用Celery异步任务队列的使用
2018/03/13 Python
Python实现替换文件中指定内容的方法
2018/03/19 Python
python进程间通信Queue工作过程详解
2019/11/01 Python
TensorFlow:将ckpt文件固化成pb文件教程
2020/02/11 Python
keras 使用Lambda 快速新建层 添加多个参数操作
2020/06/10 Python
团队经理竞聘书
2014/03/31 职场文书
债务追讨授权委托书范本
2014/10/16 职场文书
贵阳市党的群众路线教育实践活动党(工)委领导班子整改方案
2014/10/26 职场文书
辛德勒的名单观后感
2015/06/03 职场文书
css3实现背景图片颜色修改的多种方式
2021/04/13 HTML / CSS
对PyTorch中inplace字段的全面理解
2021/05/22 Python
Python-OpenCV实现图像缺陷检测的实例
2021/06/11 Python