Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
Aug 22 Python
讲解Python中if语句的嵌套用法
May 14 Python
通过python+selenium3实现浏览器刷简书文章阅读量
Dec 26 Python
基于Python的ModbusTCP客户端实现详解
Jul 13 Python
Python 根据日志级别打印不同颜色的日志的方法示例
Aug 08 Python
python 数据提取及拆分的实现代码
Aug 26 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 Python
Pytorch在NLP中的简单应用详解
Jan 08 Python
Django QuerySet查询集原理及代码实例
Jun 13 Python
win10安装python3.6的常见问题
Jul 01 Python
Python 正则模块详情
Nov 02 Python
python 使用pandas读取csv文件的方法
Dec 24 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
PHP检测用户是否关闭浏览器的方法
2016/02/14 PHP
php使用socket调用http和smtp协议实例小结
2019/07/26 PHP
js中cookie的使用详细分析
2008/05/28 Javascript
jQuery 注意事项 与原因分析
2009/04/24 Javascript
网页禁用右键实现代码(JavaScript代码)
2009/10/29 Javascript
js 模拟气泡屏保效果代码
2010/07/10 Javascript
Javascript绝句欣赏 一些经典的js代码
2012/02/22 Javascript
jQuery带箭头提示框tooltips插件集锦
2014/11/17 Javascript
javascript内置对象操作详解
2015/02/04 Javascript
基于JavaScript制作霓虹灯文字 代码 特效
2015/09/01 Javascript
日常收藏的jquery技巧
2015/12/02 Javascript
JavaScript中的ajax功能的概念和示例详解
2016/10/17 Javascript
javascript原生封装一个淡入淡出效果的函数测试实例代码
2018/03/19 Javascript
vue 父组件调用子组件方法及事件
2018/03/29 Javascript
Vue使用vux-ui自定义表单验证遇到的问题及解决方法
2018/05/10 Javascript
微信小程序表单弹窗实例
2018/07/19 Javascript
小程序点击图片实现自动播放视频
2020/05/29 Javascript
vue(2.x,3.0)配置跨域代理
2019/11/27 Javascript
结合axios对项目中的api请求进行封装操作
2020/09/21 Javascript
[07:20]2014DOTA2西雅图国际邀请赛 选手讲解积分赛第二天
2014/07/11 DOTA
跟老齐学Python之dict()的操作方法
2014/09/24 Python
Python中optparser库用法实例详解
2018/01/26 Python
PyQt5每天必学之进度条效果
2018/04/19 Python
Python全局锁中如何合理运用多线程(多进程)
2019/11/06 Python
pytorch sampler对数据进行采样的实现
2019/12/31 Python
瑞贝卡·泰勒官方网站:Rebecca Taylor
2016/09/24 全球购物
法律工作求职自荐信
2013/10/31 职场文书
室内设计专业个人的自我评价
2013/12/18 职场文书
学校清明节活动总结
2014/07/04 职场文书
简单租房协议书范本
2014/08/20 职场文书
社区护士演讲稿
2014/08/27 职场文书
公安领导班子四风问题个人整改措施思想汇报
2014/10/09 职场文书
扩展多台相同的Web服务器
2021/04/01 Servers
面试中canvas绘制图片模糊图片问题处理
2022/03/13 Javascript
Elasticsearch 批量操作
2022/04/19 Python
AndroidStudio图片压缩工具ImgCompressPlugin使用实例
2022/08/05 Java/Android