Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现将n个点均匀地分布在球面上的方法
Mar 12 Python
python操作ssh实现服务器日志下载的方法
Jun 03 Python
图文讲解选择排序算法的原理及在Python中的实现
May 04 Python
Python实现读取Properties配置文件的方法
Mar 29 Python
python3 读写文件换行符的方法
Apr 09 Python
python的pytest框架之命令行参数详解(下)
Jun 27 Python
python3的print()函数的用法图文讲解
Jul 16 Python
Django如何实现上传图片功能
Aug 16 Python
python实现12306登录并保存cookie的方法示例
Dec 17 Python
Pycharm最常用的快捷键及使用技巧
Mar 05 Python
python em算法的实现
Oct 03 Python
python selenium 获取接口数据的实现
Dec 07 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
全国FM电台频率大全 - 25 云南省
2020/03/11 无线电
过滤掉PHP数组中的重复值的实现代码
2011/07/17 PHP
php格式化json函数示例代码
2016/05/12 PHP
PHP利用超级全局变量$_POST来接收表单数据的实例
2016/11/05 PHP
微信 getAccessToken方法详解及实例
2016/11/23 PHP
Yii2中datetime类的使用
2016/12/17 PHP
PHP微信分享开发详解
2017/01/14 PHP
TP5框架页面跳转样式操作示例
2020/04/05 PHP
JavaScript DOM学习第一章 W3C DOM简介
2010/02/19 Javascript
用jquery写的菜单从左往右滑动出现
2014/04/11 Javascript
jQuery模拟点击A标记示例参考
2014/04/17 Javascript
JavaScript设置获取和设置属性的方法
2015/03/04 Javascript
Javascript闭包实例详解
2015/11/29 Javascript
JS常用算法实现代码
2016/11/14 Javascript
Vue.js实现价格计算器功能
2020/03/30 Javascript
微信小程序模版渲染详解
2018/01/26 Javascript
基于JavaScript实现单例模式
2019/10/30 Javascript
JQuery中的常用事件、对象属性与使用方法分析
2019/12/23 jQuery
swiper自定义分页器的样式
2020/09/14 Javascript
跟老齐学Python之玩转字符串(2)
2014/09/14 Python
使用C语言来扩展Python程序和Zope服务器的教程
2015/04/14 Python
Python使用turtule画五角星的方法
2015/07/09 Python
Phantomjs抓取渲染JS后的网页(Python代码)
2016/05/13 Python
如何利用python制作时间戳转换工具详解
2018/09/12 Python
python分数表示方式和写法
2019/06/26 Python
python框架flask表单实现详解
2019/11/04 Python
python读取多层嵌套文件夹中的文件实例
2020/02/27 Python
pymysql之cur.fetchall() 和cur.fetchone()用法详解
2020/05/15 Python
django实现日志按日期分割
2020/05/21 Python
CSS3 :nth-child()伪类选择器实现奇偶行显示不同样式
2013/11/05 HTML / CSS
解释一下钝化(Swap out)
2016/12/26 面试题
送货司机岗位职责
2013/12/11 职场文书
乡下人家教学反思
2014/02/01 职场文书
2015小学教育教学工作总结
2015/07/21 职场文书
2016年度基层党建工作公开承诺书
2016/03/25 职场文书
Anaconda安装pytorch及配置PyCharm 2021环境
2021/06/04 Python