Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
搭建Python的Django框架环境并建立和运行第一个App的教程
Jul 02 Python
Python操作Access数据库基本步骤分析
Sep 19 Python
详解Python读取配置文件模块ConfigParser
May 11 Python
python 中random模块的常用方法总结
Jul 08 Python
关于Python中浮点数精度处理的技巧总结
Aug 10 Python
完美解决安装完tensorflow后pip无法使用的问题
Jun 11 Python
tensorflow: 查看 tensor详细数值方法
Jun 13 Python
对python创建及引用动态变量名的示例讲解
Nov 10 Python
Python3匿名函数lambda介绍与使用示例
May 18 Python
TensorFlow实现指数衰减学习率的方法
Feb 05 Python
Python参数传递机制传值和传引用原理详解
May 22 Python
python 实现客户端与服务端的通信
Dec 23 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
PHP脚本的10个技巧(1)
2006/10/09 PHP
色色整理的PHP面试题集锦
2012/03/08 PHP
浅析is_writable的php实现
2013/06/18 PHP
PHP读取文件内容后清空文件示例代码
2014/03/18 PHP
PHP exif扩展方法开启详解
2014/07/28 PHP
PHP var关键字相关原理及使用实例解析
2020/07/11 PHP
在线游戏大家来找茬II
2006/09/30 Javascript
javascript 打印页面代码
2009/03/24 Javascript
javascript中的array数组使用技巧
2010/01/31 Javascript
动态加载图片路径 保持JavaScript控件的相对独立性
2010/09/06 Javascript
JQuery FlexiGrid的asp.net完美解决方案 dotNetFlexGrid-.Net原生的异步表格控件
2010/09/12 Javascript
浏览器解析js生成的html出现样式问题的解决方法
2012/04/16 Javascript
通过jQuery源码学习javascript(三)
2012/12/27 Javascript
只需20行代码就可以写出CSS覆盖率测试脚本
2013/04/24 Javascript
微信小程序实战之运维小项目
2017/01/17 Javascript
js实现简单的选项卡效果
2017/02/23 Javascript
js实现年月日表单三级联动
2020/04/17 Javascript
关于javascript作用域的常见面试题分享
2017/06/18 Javascript
Vue.directive()的用法和实例详解
2018/03/04 Javascript
关于vue面试题汇总
2018/03/20 Javascript
JavaScript在web自动化测试中的作用示例详解
2019/08/25 Javascript
关于IDEA中的.VUE文件报错 Export declarations are not supported by current JavaScript version
2020/10/17 Javascript
Vue-Ant Design Vue-普通及自定义校验实例
2020/10/24 Javascript
vue 避免变量赋值后双向绑定的操作
2020/11/07 Javascript
python删除文件示例分享
2014/01/28 Python
在Python中操作字典之update()方法的使用
2015/05/22 Python
python range()函数取反序遍历sequence的方法
2018/06/25 Python
【python】matplotlib动态显示详解
2019/04/11 Python
python 通过手机号识别出对应的微信性别(实例代码)
2019/12/22 Python
关于Tensorflow 模型持久化详解
2020/02/12 Python
Django中的模型类设计及展示示例详解
2020/05/29 Python
html5使用canvas实现图片下载功能的示例代码
2017/08/26 HTML / CSS
HTML5视频播放插件 video.js介绍
2018/09/29 HTML / CSS
学生档案自我鉴定
2013/10/07 职场文书
物理教育专业求职信
2014/06/25 职场文书
SQL Server使用T-SQL语句批处理
2022/05/20 SQL Server