Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中django框架通过正则搜索页面上email地址的方法
Mar 21 Python
浅谈python 线程池threadpool之实现
Nov 17 Python
python遍历一个目录,输出所有的文件名的实例
Apr 23 Python
Python实现获取汉字偏旁部首的方法示例【测试可用】
Dec 18 Python
详解从Django Rest Framework响应中删除空字段
Jan 11 Python
Python3 关于pycharm自动导入包快捷设置的方法
Jan 16 Python
python tkinter实现界面切换的示例代码
Jun 14 Python
基于python的Paxos算法实现
Jul 03 Python
tensorflow之变量初始化(tf.Variable)使用详解
Feb 06 Python
Django如何使用jwt获取用户信息
Apr 21 Python
常用的10个Python实用小技巧
Aug 10 Python
python UDF 实现对csv批量md5加密操作
Jan 01 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
PHP脚本的10个技巧(2)
2006/10/09 PHP
PHP中cookies使用指南
2007/03/16 PHP
什么是MVC,好东西啊
2007/05/03 PHP
php中使用redis队列操作实例代码
2013/02/07 PHP
PHP rsa加密解密使用方法
2015/04/27 PHP
PHP开发中解决并发问题的几种实现方法分析
2017/11/13 PHP
PHP PDOStatement::errorInfo讲解
2019/01/31 PHP
Laravel5.1 框架Middleware中间件基本用法实例分析
2020/01/04 PHP
PHP实现倒计时功能
2020/11/16 PHP
Javascript 面向对象 继承
2010/05/13 Javascript
Jquery判断IE6等浏览器的代码
2011/04/05 Javascript
jquery快捷动态绑定键盘事件的操作函数代码
2013/10/17 Javascript
javascript html实现网页版日历代码
2016/03/08 Javascript
JavaScript基于自定义函数判断变量类型的实现方法
2016/11/23 Javascript
javascript事件的绑定基础实例讲解(34)
2017/02/14 Javascript
Vue.js如何实现路由懒加载浅析
2017/08/14 Javascript
vue cli webpack中使用sass的方法
2018/02/24 Javascript
jQuery实现浏览器之间跳转并传递参数功能【支持中文字符】
2018/03/28 jQuery
vue父组件触发事件改变子组件的值的方法实例详解
2019/05/07 Javascript
vue实现分页栏效果
2019/06/28 Javascript
关于JS解构的5种有趣用法
2019/09/05 Javascript
Windows下Python使用Pandas模块操作Excel文件的教程
2016/05/31 Python
python微信跳一跳游戏辅助代码解析
2018/01/29 Python
分享Pycharm中一些不为人知的技巧
2018/04/03 Python
python对excel文档去重及求和的实例
2018/04/18 Python
python将字符串以utf-8格式保存在txt文件中的方法
2018/10/30 Python
python 删除字符串中连续多个空格并保留一个的方法
2018/12/22 Python
在windows下使用python进行串口通讯的方法
2019/07/02 Python
python issubclass 和 isinstance函数
2019/07/25 Python
Django ORM 查询管理器源码解析
2019/08/05 Python
python利用dlib获取人脸的68个landmark
2019/11/27 Python
Html5剪切板功能的实现代码
2018/06/29 HTML / CSS
Skyscanner阿联酋:全球领先的旅游搜索平台
2017/11/25 全球购物
测绘工程个人的自我评价
2013/11/10 职场文书
4S店销售内勤岗位职责
2015/04/13 职场文书
导游词之湖北武当山
2019/09/23 职场文书