PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
通过python下载FTP上的文件夹的实现代码
Feb 10 Python
python中使用OpenCV进行人脸检测的例子
Apr 18 Python
python django 增删改查操作 数据库Mysql
Jul 27 Python
详解Python中的Numpy、SciPy、MatPlotLib安装与配置
Nov 17 Python
python判断一个集合是否为另一个集合的子集方法
May 04 Python
python使用turtle绘制分形树
Jun 22 Python
使用pandas把某一列的字符值转换为数字的实例
Jan 29 Python
python 类之间的参数传递方式
Dec 20 Python
python获取命令行参数实例方法讲解
Nov 02 Python
Python-openpyxl表格读取写入的案例详解
Nov 02 Python
教你用Python爬取英雄联盟皮肤原画
Jun 13 Python
pycharm无法安装cv2模块问题
May 20 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
使用PHP 5.0创建图形的巧妙方法
2010/10/12 PHP
php去除字符串中空字符的常用方法小结
2015/03/17 PHP
分享10段PHP常用代码
2015/11/11 PHP
PHP中的访问修饰符简单比较
2019/02/02 PHP
PHP实现数组和对象的相互转换操作示例
2019/03/20 PHP
Web层改进II-用xmlhttp 无声息提交复杂表单
2007/01/22 Javascript
讲两件事:1.this指针的用法小探. 2.ie的attachEvent和firefox的addEventListener在事件处理上的区别
2007/04/12 Javascript
Prototype使用指南之selector.js说明
2008/10/26 Javascript
初窥JQuery(一)jquery选择符 必备知识点
2010/11/25 Javascript
javascript得到当前页的来路即前一页地址的方法
2014/02/18 Javascript
JS实现间歇滚动的运动效果实例
2016/12/22 Javascript
jQuery实现的简单排序功能示例【冒泡排序】
2017/01/13 Javascript
jQuery插件开发发送短信倒计时功能代码
2017/05/09 jQuery
探索webpack模块及webpack3新特性
2017/09/18 Javascript
Bootstrap图片轮播效果详解
2017/10/17 Javascript
解析Vue2 dist 目录下各个文件的区别
2017/11/22 Javascript
深入理解nodejs搭建静态服务器(实现命令行)
2019/02/05 NodeJs
javascript+HTML5 canvas绘制时钟功能示例
2019/05/15 Javascript
小程序中this.setData的使用和注意事项
2019/08/28 Javascript
javascript的delete运算符知识点总结
2019/11/19 Javascript
Python3 操作符重载方法示例
2017/11/23 Python
Python断言assert的用法代码解析
2018/02/03 Python
Python装饰器用法实例总结
2018/02/07 Python
浅谈Python接口对json串的处理方法
2018/12/19 Python
Python使用Paramiko控制liunx第三方库
2020/05/20 Python
浅谈keras.callbacks设置模型保存策略
2020/06/18 Python
css3的transition属性详解
2014/12/15 HTML / CSS
购买正版游戏和游戏激活码:Green Man Gaming
2019/11/06 全球购物
高三历史教学反思
2014/01/09 职场文书
外贸采购员岗位职责
2014/03/08 职场文书
超市优秀员工事迹材料
2014/05/01 职场文书
公司授权委托书样本
2014/09/15 职场文书
英语投诉信范文
2015/07/03 职场文书
个人工作决心书
2015/09/22 职场文书
党风廉政建设心得体会(2016最新版)
2016/01/22 职场文书
MongoDB 常用的crud操作语句
2021/06/20 MongoDB