Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程锁(thread)学习示例
Dec 04 Python
Python语言的12个基础知识点小结
Jul 10 Python
Python入门篇之正则表达式
Oct 20 Python
python中dir函数用法分析
Apr 17 Python
Python格式化日期时间操作示例
Jun 28 Python
Python实现DDos攻击实例详解
Feb 02 Python
python UDP(udp)协议发送和接收的实例
Jul 22 Python
Python基于OpenCV实现人脸检测并保存
Jul 23 Python
python实现低通滤波器代码
Feb 26 Python
Python龙贝格法求积分实例
Feb 29 Python
解决jupyter notebook 出现In[*]的问题
Apr 13 Python
Python pytesseract验证码识别库用法解析
Jun 29 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
将PHP作为Shell脚本语言使用
2006/10/09 PHP
用PHP将数据导入到Foxmail的实现代码
2010/09/05 PHP
PHP url 加密解密函数代码
2011/08/26 PHP
使用PHPMyAdmin修复论坛数据库的图文方法
2012/01/09 PHP
window.location.hash 属性使用说明
2010/03/20 Javascript
javascript折半查找详解
2015/01/26 Javascript
3种js实现string的substring方法
2015/11/09 Javascript
理解JavaScript原型链
2016/10/25 Javascript
浅谈jQuery this和$(this)的区别及获取$(this)子元素对象的方法
2016/11/29 Javascript
原生js实现电商侧边导航效果
2017/01/19 Javascript
JavaScript实现256色转灰度图
2017/02/22 Javascript
vue-dialog的弹出层组件
2020/05/25 Javascript
原生JS实现圆环拖拽效果
2017/04/07 Javascript
xmlplus组件设计系列之文本框(TextBox)(3)
2017/05/03 Javascript
JS实现按钮添加背景音乐示例代码
2017/10/17 Javascript
webpack 静态资源集中输出的方法示例
2018/11/09 Javascript
vue计算属性computed、事件、监听器watch的使用讲解
2019/01/21 Javascript
深入学习JavaScript中的bom
2019/05/27 Javascript
vuex实现购物车功能
2020/06/28 Javascript
基于Python的XSS测试工具XSStrike使用方法
2017/07/29 Python
python开启debug模式的方法
2019/06/27 Python
Python爬虫 bilibili视频弹幕提取过程详解
2019/07/31 Python
Python调用Windows命令打印文件
2020/02/07 Python
Python如何实现自带HTTP文件传输服务
2020/07/08 Python
详解CSS3中字体平滑处理和抗锯齿渲染
2017/03/29 HTML / CSS
CK美国官网:Calvin Klein
2016/08/26 全球购物
Ralph Lauren英国官方网站:Ralph Lauren UK
2018/04/03 全球购物
金属材料工程个人求职的自我评价
2013/12/04 职场文书
幼儿园中班教学反思
2014/02/10 职场文书
《广玉兰》教学反思
2014/04/14 职场文书
查摆问题整改措施
2014/10/24 职场文书
2014年高校辅导员工作总结
2014/12/09 职场文书
离婚协议书样本
2015/01/26 职场文书
2015年信息化建设工作总结
2015/07/23 职场文书
初一英语教学反思
2016/02/15 职场文书
如何在CocosCreator里画个炫酷的雷达图
2021/04/16 Javascript