Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串定义
Sep 25 Python
Python素数检测实例分析
Jun 15 Python
Python实现包含min函数的栈
Apr 29 Python
读取json格式为DataFrame(可转为.csv)的实例讲解
Jun 05 Python
python 解压pkl文件的方法
Oct 25 Python
python实现简单名片管理系统
Nov 30 Python
python3 tkinter实现添加图片和文本
Nov 26 Python
python将四元数变换为旋转矩阵的实例
Dec 04 Python
Python GUI编程学习笔记之tkinter事件绑定操作详解
Mar 30 Python
Django实现将一个字典传到前端显示出来
Apr 03 Python
如何解决安装python3.6.1失败
Jul 01 Python
python爬虫今日热榜数据到txt文件的源码
Feb 23 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
php FPDF类库应用实现代码
2009/03/20 PHP
IIS安装Apache伪静态插件的具体操作图文
2013/07/01 PHP
PHP实现将科学计数法转换为原始数字字符串的方法
2014/12/16 PHP
php封装好的人民币数值转中文大写类
2015/12/20 PHP
详解PHP数据压缩、加解密(pack, unpack)
2016/12/17 PHP
PHP格式化显示时间date()函数代码
2018/10/03 PHP
javascript 变量作用域 代码分析
2009/06/26 Javascript
js的逻辑运算符 ||
2010/05/31 Javascript
js的2种继承方式详解
2014/03/04 Javascript
JavaScript实现找出数组中最长的连续数字序列
2014/09/03 Javascript
纯javascript制作日历控件
2015/07/17 Javascript
jQuery简单操作cookie的插件实例
2016/01/13 Javascript
jQuery animate和CSS3相结合实现缓动追逐效果附源码下载
2016/04/18 Javascript
javascript中this指向详解
2016/04/23 Javascript
动态的9*9乘法表效果的实现代码
2016/05/16 Javascript
JavaScript判断数组是否存在key的简单实例
2016/08/03 Javascript
input 禁止输入特殊字符的四种实现方式
2016/08/24 Javascript
JavaScript运动框架 多值运动(四)
2017/05/18 Javascript
详解基于原生JS验证表单组件xy-form
2019/08/20 Javascript
[13:55]Newbee vs Team Spirit
2018/06/07 DOTA
Python2.x版本中maketrans()方法的使用介绍
2015/05/19 Python
Python两个内置函数 locals 和globals(学习笔记)
2016/08/28 Python
利用Python获取操作系统信息实例
2016/09/02 Python
Python实现抓取网页生成Excel文件的方法示例
2017/08/05 Python
Python字符串格式化的方法(两种)
2017/09/19 Python
django 控制页面跳转的例子
2019/08/06 Python
使用Python代码实现Linux中的ls遍历目录命令的实例代码
2019/09/07 Python
python numpy数组中的复制知识解析
2020/02/03 Python
python opencv 图像边框(填充)添加及图像混合的实现方法(末尾实现类似幻灯片渐变的效果)
2020/03/09 Python
对python中list的五种查找方法说明
2020/07/13 Python
《独坐敬亭山》教学反思
2014/04/08 职场文书
个人授权委托书格式
2014/08/30 职场文书
2014年第四季度入党积极分子思想汇报(十八届四中全会)
2014/11/03 职场文书
上课讲话检讨书范文
2015/05/07 职场文书
pandas中对文本类型数据的处理小结
2021/11/01 Python
Flink 侧流输出源码示例解析
2022/09/23 Servers