Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python记录运行pid,并在需要时kill掉它们的实例
Jan 16 Python
python读取二进制mnist实例详解
May 31 Python
python实现n个数中选出m个数的方法
Nov 13 Python
pywinauto自动化操作记事本
Aug 26 Python
tornado+celery的简单使用详解
Dec 21 Python
Python实现疫情通定时自动填写功能(附代码)
May 27 Python
Python decimal模块使用方法详解
Jun 08 Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 Python
python 利用toapi库自动生成api
Oct 19 Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 Python
python网络爬虫实现发送短信验证码的方法
Feb 25 Python
Python3 多线程(连接池)操作MySQL插入数据
Jun 09 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
PHP基于SimpleXML生成和解析xml的方法示例
2017/07/17 PHP
解决Laravel blade模板转义html标签的问题
2019/09/03 PHP
JS获取scrollHeight问题想到的标准问题
2007/05/27 Javascript
javascript jQuery $.post $.ajax用法
2008/07/09 Javascript
Exjs 入门篇
2010/04/07 Javascript
fancybox1.3.1 基于Jquery的插件在IE中图片显示问题
2010/10/01 Javascript
对xmlHttp对象的理解
2011/01/17 Javascript
jQuery表格行换色的三种实现方法
2011/06/27 Javascript
JavaScript面向对象程序设计三 原型模式(上)
2011/12/21 Javascript
jQuery删除节点的三个方法即remove()detach()和empty()
2013/12/27 Javascript
Node.js中对通用模块的封装方法
2014/06/06 Javascript
JQuery实现动态表格点击按钮表格增加一行
2014/08/24 Javascript
优化Node.js Web应用运行速度的10个技巧
2014/09/03 Javascript
Node.js中使用mongoskin操作mongoDB实例
2014/09/28 Javascript
setTimeout内不支持jquery的选择器的解决方案
2015/04/28 Javascript
深入浅出理解javaScript原型链
2015/05/09 Javascript
浅谈JavaScript 的执行顺序
2015/08/07 Javascript
jsonp跨域请求数据实现手机号码查询实例分析
2015/12/12 Javascript
JavaScript编写带旋转+线条干扰的验证码脚本实例
2016/05/30 Javascript
jQuery UI仿淘宝搜索下拉列表功能
2017/01/10 Javascript
JavaScript选取(picking)和反选(rejecting)对象的属性方法
2017/08/16 Javascript
使用 Javascript 实现浏览器推送提醒功能的示例
2017/11/03 Javascript
js中call()和apply()改变指针问题的讲解
2019/01/17 Javascript
微信小程序 scroll-view的使用案例代码详解
2020/06/11 Javascript
jQuery开发仿QQ版音乐播放器
2020/07/10 jQuery
浅谈Python的条件判断语句if/else语句
2019/03/21 Python
前端canvas水印快速制作(附完整代码)
2019/09/19 HTML / CSS
如果有两个类A,B,怎么样才能使A在发生一个事件的时候通知B
2016/03/12 面试题
如何撰写岗位职责
2014/02/01 职场文书
法人授权委托书
2014/09/16 职场文书
2014年高中教师工作总结
2014/12/19 职场文书
2016年小学生迎国庆广播稿
2015/12/18 职场文书
2016年小学六一儿童节活动总结
2016/04/06 职场文书
MySQL COUNT函数的使用与优化
2021/05/10 MySQL
win11系统中dhcp服务异常什么意思? Win11 DHCP服务异常修复方法
2022/04/08 数码科技
Python自动化工具之实现Excel转Markdown表格
2022/04/08 Python