PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用Python求解最大公约数的实现方法
Aug 20 Python
Python基于tkinter模块实现的改名小工具示例
Jul 27 Python
Python实现批量压缩图片
Jan 25 Python
使用numpy和PIL进行简单的图像处理方法
Jul 02 Python
python实现简单名片管理系统
Nov 30 Python
django创建超级用户过程解析
Sep 18 Python
基于Python实现大文件分割和命名脚本过程解析
Sep 29 Python
python pyinstaller打包exe报错的解决方法
Nov 02 Python
TensorFLow 数学运算的示例代码
Apr 21 Python
pycharm 激活码及使用方式的详细教程
May 12 Python
python Matplotlib数据可视化(1):简单入门
Sep 30 Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
php对图像的各种处理函数代码小结
2013/07/08 PHP
在Windows系统下使用PHP生成Word文档的教程
2015/07/03 PHP
PHP中函数gzuncompress无法使用的解决方法
2017/03/02 PHP
Yii2设置默认控制器的两种方法
2017/05/19 PHP
限制复选框的最大可选数
2006/07/01 Javascript
js字符编码函数区别分析
2008/06/05 Javascript
Wordpress ThickBox 点击图片显示下一张图的修改方法
2010/12/11 Javascript
jquery mobile changepage的三种传参方法介绍
2013/09/13 Javascript
javascript运行机制之this详细介绍
2014/02/07 Javascript
javascript属性访问表达式用法分析
2015/04/25 Javascript
javascript原生ajax写法分享
2016/04/10 Javascript
jQuery each函数源码分析
2016/05/25 Javascript
利用node.js搭建简单web服务器的方法教程
2017/02/20 Javascript
用户管理的设计_jquery的ajax实现二级联动效果
2017/07/13 jQuery
jquery 获取索引值在一定范围的列表方法
2018/01/25 jQuery
nginx+vue.js实现前后端分离的示例代码
2018/02/12 Javascript
JS实现计算小于非负数n的素数的数量算法示例
2019/02/26 Javascript
Vue事件修饰符native、self示例详解
2019/07/09 Javascript
Vue中this.$nextTick的作用及用法
2020/02/04 Javascript
vue中组件通信详解(父子组件, 爷孙组件, 兄弟组件)
2020/07/27 Javascript
vue基于Echarts的拖拽数据可视化功能实现
2020/12/04 Vue.js
javascript实现简单页面倒计时
2021/03/02 Javascript
使用Python生成随机密码的示例分享
2016/02/18 Python
Python实现信用卡系统(支持购物、转账、存取钱)
2016/06/24 Python
Python自定义一个类实现字典dict功能的方法
2019/01/19 Python
Python 3.8 新功能全解
2019/07/25 Python
python matplotlib饼状图参数及用法解析
2019/11/04 Python
Python之Django自动实现html代码(下拉框,数据选择)
2020/03/13 Python
Django DRF路由与扩展功能的实现
2020/06/03 Python
大班幼儿评语大全
2014/04/30 职场文书
学雷锋演讲稿汇总
2014/05/10 职场文书
爱国主义教育演讲稿
2014/08/26 职场文书
2014法院干警廉洁警示教育思想汇报
2014/09/13 职场文书
2015年行政人事部工作总结
2015/05/13 职场文书
Vue中插槽slot的使用方法与应用场景详析
2021/06/08 Vue.js
Python os和os.path模块详情
2022/04/02 Python