Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用item()方法遍历字典的例子
Aug 26 Python
Python+Selenium自动化实现分页(pagination)处理
Mar 31 Python
Python实现对字符串的加密解密方法示例
Apr 29 Python
python利用urllib实现爬取京东网站商品图片的爬虫实例
Aug 24 Python
使用Python为中秋节绘制一块美味的月饼
Sep 11 Python
Python 单例设计模式用法实例分析
Sep 23 Python
python NumPy ndarray二维数组 按照行列求平均实例
Nov 26 Python
tensorflow保持每次训练结果一致的简单实现
Feb 17 Python
Python运行异常管理解决方案
Mar 09 Python
python 实现PIL模块在图片画线写字
May 16 Python
BeautifulSoup中find和find_all的使用详解
Dec 07 Python
python利用文件时间批量重命名照片和视频
Feb 09 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
php 正确解码javascript中通过escape编码后的字符
2010/01/28 PHP
PHP设计模式之装饰器模式实例详解
2018/02/07 PHP
JAVASCRIPT HashTable
2007/01/22 Javascript
javascript数组使用调用方法汇总
2007/12/08 Javascript
修改jQuery.Autocomplete插件 支持中文输入法 避免TAB、ENTER键失效、导致表单提交
2009/10/11 Javascript
javascript scrollTop正解使用方法
2013/11/14 Javascript
用JavaScript实现一个代码简洁、逻辑不复杂的多级树
2014/05/23 Javascript
Javascript实现鼠标右键特色菜单
2015/08/04 Javascript
JS实现样式清新的横排下拉菜单效果
2015/10/09 Javascript
原生JavaScript实现动态省市县三级联动下拉框菜单实例代码
2016/02/03 Javascript
JavaScript中的splice方法用法详解
2016/07/20 Javascript
EasyUI的doCellTip实现鼠标放到单元格上提示单元格内容
2016/08/24 Javascript
Vue项目全局配置页面缓存之按需读取缓存的实现详解
2018/08/01 Javascript
antd-日历组件,前后禁止选择,只能选中间一部分的实例
2020/10/29 Javascript
python批量实现Word文件转换为PDF文件
2018/03/15 Python
更换Django默认的模板引擎为jinja2的实现方法
2018/05/28 Python
使用selenium模拟登录解决滑块验证问题的实现
2019/05/10 Python
numpy.linalg.eig() 计算矩阵特征向量方式
2019/11/29 Python
python dataframe NaN处理方式
2019/12/26 Python
Python多线程获取返回值代码实例
2020/02/17 Python
Python 批量读取文件中指定字符的实现
2020/03/06 Python
利用Python实现学生信息管理系统的完整实例
2020/12/30 Python
html5中svg canvas和图片之间相互转化思路代码
2014/01/24 HTML / CSS
HTML5 Canvas锯齿图代码实例
2014/04/10 HTML / CSS
荷兰网上买鞋:MooieSchoenen.nl
2017/09/12 全球购物
Chinti & Parker官网:奢华羊绒女装和创新针织设计
2021/01/01 全球购物
瑞士首家网上药店折扣店:McDrogerie
2020/12/22 全球购物
PHP如何与mysql建立链接
2013/05/05 面试题
就业协议书范本
2014/04/11 职场文书
我爱我的祖国演讲稿
2014/05/04 职场文书
中学生打架检讨书
2014/10/13 职场文书
暑期社会实践证明书
2014/11/17 职场文书
2015年南京大屠杀纪念日活动总结
2015/03/24 职场文书
浏览器常用基本操作之python3+selenium4自动化测试(基础篇3)
2021/05/21 Python
Element-ui Layout布局(Row和Col组件)的实现
2021/12/06 Vue.js
MySQL外键约束(Foreign Key)案例详解
2022/06/28 MySQL