Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python判断、获取一张图片主色调的2个实例
Apr 10 Python
python使用threading获取线程函数返回值的实现方法
Nov 15 Python
python获取网页中所有图片并筛选指定分辨率的方法
Mar 31 Python
python-docx修改已存在的Word文档的表格的字体格式方法
May 08 Python
python如何发布自已pip项目的方法步骤
Oct 09 Python
Pandas 按索引合并数据集的方法
Nov 15 Python
python之验证码生成(gvcode与captcha)
Jan 02 Python
Python高级编程之继承问题详解(super与mro)
Nov 19 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
Django 设置多环境配置文件载入问题
Feb 25 Python
python使用smtplib模块发送邮件
Dec 17 Python
pytorch 计算Parameter和FLOP的操作
Mar 04 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
世界第一个无线广播电台 KDKA
2021/03/01 无线电
PHP文件缓存smarty模板应用实例分析
2016/02/26 PHP
Laravel程序架构设计思路之使用动作类
2018/06/07 PHP
PDO::beginTransaction讲解
2019/01/27 PHP
JavaScript面向对象之Prototypes和继承
2012/07/12 Javascript
JQuery调webservice实现邮箱验证(检测是否可用)
2013/05/21 Javascript
cookie.js 加载顺序问题怎么才有效
2013/07/31 Javascript
js导出txt示例代码
2014/01/14 Javascript
javascript使用正则表达式检测IP地址
2014/12/03 Javascript
JavaScript实现点击按钮直接打印
2016/01/06 Javascript
JS数组去掉重复数据只保留一条的实现代码
2016/08/11 Javascript
浅谈js的ajax的异步和同步请求的问题
2016/10/07 Javascript
详解Angular的8个主要构造块
2017/06/20 Javascript
Node.js使用Koa搭建 基础项目
2018/01/08 Javascript
微信小程序异步API为Promise简化异步编程的操作方法
2018/08/14 Javascript
layui-laydate时间日历控件使用方法详解
2018/11/15 Javascript
解决layer弹出层自适应页面大小的问题
2019/09/16 Javascript
Python中os和shutil模块实用方法集锦
2014/05/13 Python
利用Python破解斗地主残局详解
2017/06/30 Python
基于Python代码编辑器的选用(详解)
2017/09/13 Python
完美解决python3.7 pip升级 拒绝访问问题
2019/07/12 Python
Python Opencv任意形状目标检测并绘制框图
2019/07/23 Python
Python实现疫情通定时自动填写功能(附代码)
2020/05/27 Python
python中的yield from语法快速学习
2020/11/06 Python
canvas绘制树形结构可视图形的实现
2020/04/03 HTML / CSS
联想韩国官网:Lenovo Korea
2018/05/10 全球购物
都柏林通行卡/城市通票:The Dublin Pass
2020/02/16 全球购物
会计大学生职业生涯规划书范文
2014/01/13 职场文书
拓展策划方案
2014/06/03 职场文书
社区娱乐活动方案
2014/08/21 职场文书
党员个人查摆剖析材料
2014/10/16 职场文书
合同权益转让协议书模板
2014/11/18 职场文书
先进个人事迹材料
2014/12/29 职场文书
Redis持久化与主从复制的实践
2021/04/27 Redis
python异常中else的实例用法
2021/06/15 Python
Win10 最新稳定版本 21H2开始推送
2022/04/19 数码科技