Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python以环状形式组合排列图片并输出的方法
Mar 17 Python
简单总结Python中序列与字典的相同和不同之处
Jan 19 Python
node.js获取参数的常用方法(总结)
May 29 Python
Python二叉树的定义及常用遍历算法分析
Nov 24 Python
Python2.X/Python3.X中urllib库区别讲解
Dec 19 Python
python获取文件真实链接的方法,针对于302返回码
May 14 Python
Python实现的远程登录windows系统功能示例
Jun 21 Python
python selenium循环登陆网站的实现
Nov 04 Python
解决TensorFlow GPU版出现OOM错误的问题
Feb 03 Python
Python基于pandas绘制散点图矩阵代码实例
Jun 04 Python
使用pytorch 筛选出一定范围的值
Jun 28 Python
Python eval函数介绍及用法
Nov 09 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
总结PHP中DateTime的常用方法
2016/08/11 PHP
PHP与以太坊交互详解
2018/08/24 PHP
js停止输出代码
2008/07/20 Javascript
js控制表单不能输入空格的小例子
2013/11/20 Javascript
JS中的数组的sort方法使用示例
2014/01/22 Javascript
微信小程序 wx.uploadFile无法上传解决办法
2016/12/14 Javascript
jquery代码规范让代码越来越好看
2017/02/03 Javascript
jQuery插件FusionCharts实现的MSBar2D图效果示例【附demo源码】
2017/03/24 jQuery
详解Node.js access_token的获取、存储及更新
2017/06/20 Javascript
解决AjaxFileupload 上传时会出现连接重置的问题
2017/07/07 Javascript
JavaScript-定时器0~9抽奖系统详解(代码)
2017/08/16 Javascript
vue2.0移除或更改的一些东西(移除index key)
2017/08/28 Javascript
利用jQuery实现简单的拖曳效果实例代码
2017/10/20 jQuery
有趣的JavaScript隐式类型转换操作实例分析
2020/05/02 Javascript
Openlayers学习之加载鹰眼控件
2020/09/28 Javascript
[17:00]DOTA2 HEROS教学视频教你分分钟做大人-帕克
2014/06/10 DOTA
python+mysql实现简单的web程序
2014/09/11 Python
Python实现自动为照片添加日期并分类的方法
2017/09/30 Python
运动检测ViBe算法python实现代码
2018/01/09 Python
pycharm 将django中多个app放到同个文件夹apps的处理方法
2018/05/30 Python
深入学习python多线程与GIL
2019/08/26 Python
python内置模块collections知识点总结
2019/12/19 Python
Python实现Word表格转成Excel表格的示例代码
2020/04/16 Python
Python改变对象的字符串显示的方法
2020/08/01 Python
python request 模块详细介绍
2020/11/10 Python
J2SDK1.5与J2SDK5.0有什么区别
2012/09/19 面试题
Ref与out有什么不同
2012/11/24 面试题
银行求职信
2014/05/31 职场文书
启动仪式策划方案
2014/06/14 职场文书
学雷锋活动倡议书
2014/08/30 职场文书
酒店辞职信怎么写
2015/02/27 职场文书
2015年数学教研工作总结
2015/07/22 职场文书
采购部2015年度工作总结
2015/07/24 职场文书
ORACLE数据库对long类型字段进行模糊匹配的解决思路
2021/04/07 Oracle
MySQL中CURRENT_TIMESTAMP的使用方式
2021/11/27 MySQL
MyBatis核心源码深度剖析SQL语句执行过程
2022/05/20 Java/Android