PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python+django实现文件上传
Jan 17 Python
python 全局变量的import机制介绍
Sep 07 Python
详解python 拆包可迭代数据如tuple, list
Dec 29 Python
Python操作Sql Server 2008数据库的方法详解
May 17 Python
Python中pandas模块DataFrame创建方法示例
Jun 20 Python
Python基础知识点 初识Python.md
May 14 Python
Django缓存系统实现过程解析
Aug 02 Python
python读取指定字节长度的文本方法
Aug 27 Python
python递归下载文件夹下所有文件
Aug 31 Python
pymysql 开启调试模式的实现
Sep 24 Python
python中return如何写
Jun 18 Python
Python实现自动装机功能案例分析
Oct 22 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
PHP开发微信支付的代码分享
2014/05/25 PHP
php实现建立多层级目录的方法
2014/07/19 PHP
PHP面向对象程序设计之对象克隆clone和魔术方法__clone()用法分析
2019/06/12 PHP
解决 firefox 不支持 document.all的方法
2007/03/12 Javascript
javascript setTimeout和setInterval 的区别
2009/12/08 Javascript
javascript操作cookie的文章(设置,删除cookies)
2010/04/01 Javascript
框架页面高度自动刷新的Javascript脚本
2013/11/01 Javascript
jQuery学习笔记之jQuery原型属性和方法
2014/06/09 Javascript
JavaScript函数详解
2014/11/17 Javascript
原生javascript实现图片滚动、延时加载功能
2015/01/12 Javascript
百度UEditor编辑器如何关闭抓取远程图片功能
2015/03/03 Javascript
js拆分字符串并将分割的数据放到数组中的方法
2015/05/06 Javascript
js实现选中复选框文字变色的方法
2015/08/14 Javascript
跟我学习javascript创建对象(类)的8种方法
2015/11/20 Javascript
jQuery on()方法绑定动态元素的点击事件实例代码浅析
2016/06/16 Javascript
RequireJS多页面应用实例分析
2016/06/29 Javascript
AngularJS之依赖注入模拟实现
2016/08/19 Javascript
AngularJS指令与控制器之间的交互功能示例
2016/12/14 Javascript
AngularJS 在同一个界面启动多个ng-app应用模块详解
2016/12/20 Javascript
jquery实现文字单行横移或翻转(上下、左右跳转)
2017/01/08 Javascript
使用pm2部署node生产环境的方法步骤
2019/03/09 Javascript
JS检索下拉列表框中被选项目的索引号(selectedIndex)
2019/12/17 Javascript
vue中改变滚动条样式的方法
2020/03/03 Javascript
[01:56]2014DOTA2西雅图邀请赛 MVP外卡赛老队长精辟点评
2014/07/09 DOTA
Python 返回汉字的汉语拼音
2009/02/27 Python
Python实现PS滤镜碎片特效功能示例
2018/01/24 Python
Python(TensorFlow框架)实现手写数字识别系统的方法
2018/05/29 Python
Numpy之文件存取的示例代码
2018/08/03 Python
python中pytest收集用例规则与运行指定用例详解
2019/06/27 Python
浅谈django 重载str 方法
2020/05/19 Python
python中reload重载实例用法
2020/12/15 Python
HTML5新增的表单元素和属性实例解析
2014/07/07 HTML / CSS
企业项目策划书
2014/01/11 职场文书
我爱我家教学反思
2014/05/01 职场文书
写给医生的感谢信
2015/01/22 职场文书
网络安全倡议书(3篇)
2019/09/18 职场文书