Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python33 urllib2使用方法细节讲解
Dec 03 Python
详解Python的Django框架中的通用视图
May 04 Python
python协程用法实例分析
Jun 04 Python
python爬取淘宝商品详情页数据
Feb 23 Python
pycharm 将django中多个app放到同个文件夹apps的处理方法
May 30 Python
python查看模块安装位置的方法
Oct 16 Python
python时间序列按频率生成日期的方法
May 14 Python
python3常用的数据清洗方法(小结)
Oct 31 Python
Python协程 yield与协程greenlet简单用法示例
Nov 22 Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 Python
python numpy矩阵信息说明,shape,size,dtype
May 22 Python
详解python模块pychartdir安装及导入问题
Oct 22 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
火车头采集器3.0采集图文教程
2007/03/17 PHP
php网站地图生成类示例
2014/01/13 PHP
getJSON跨域SyntaxError问题分析
2014/08/07 PHP
php编写批量生成不重复的卡号密码代码
2015/05/14 PHP
win10环境PHP 7 安装配置【教程】
2016/05/09 PHP
PHP设计模式之委托模式定义与用法简单示例
2018/08/13 PHP
win10下 php安装seaslog扩展的详细步骤
2020/12/04 PHP
jquery常用技巧及常用方法列表集合
2011/04/06 Javascript
jquery1.10给新增元素绑定事件的方法
2014/03/06 Javascript
jquery easyui 对于开始时间小于结束时间的判断示例
2014/03/22 Javascript
jquery 扑捉回车键事件代码
2014/04/24 Javascript
javascript屏蔽右键代码
2014/05/15 Javascript
JavaScript使用RegExp进行正则匹配的方法
2015/07/11 Javascript
JS实现焦点图轮播效果的方法详解
2016/12/19 Javascript
详解React之父子组件传递和其它一些要点
2018/06/25 Javascript
vue-cli 首屏加载优化问题
2018/11/06 Javascript
快速搭建Node.js(Express)用户注册、登录以及授权的方法
2019/05/09 Javascript
vue和iview实现Scroll 数据无限滚动功能
2019/10/31 Javascript
Vue 实现对quill-editor组件中的工具栏添加title
2020/08/03 Javascript
js闭包的9个使用场景
2020/12/29 Javascript
python实现按长宽比缩放图片
2018/06/07 Python
解决python3读取Python2存储的pickle文件问题
2018/10/25 Python
python多进程控制学习小结
2018/10/31 Python
对python制作自己的数据集实例讲解
2018/12/12 Python
在OpenCV里使用Camshift算法的实现
2019/11/22 Python
python 实现多维数组转向量
2019/11/30 Python
意大利高端时尚买手店:Stefania Mode
2018/03/01 全球购物
耐克中国官方商城:Nike中国
2018/10/18 全球购物
英国玛莎百货澳大利亚:Marks & Spencer Australia
2019/08/30 全球购物
东南亚冒险旅行与活动:Adventoro
2019/10/16 全球购物
测试工程师程序员求职信范文
2014/02/20 职场文书
政风行风建设责任书
2014/07/23 职场文书
股东出资证明书(正规版)
2014/09/24 职场文书
七年级上册语文教学计划
2015/01/22 职场文书
大学生英文求职信范文
2015/03/19 职场文书
2015年基层党支部工作总结
2015/05/21 职场文书