pytorch 实现多个Dataloader同时训练


Posted in Python onMay 29, 2021

看代码吧~

pytorch 实现多个Dataloader同时训练

如果两个dataloader的长度不一样,那就加个:

from itertools import cycle

仅使用zip,迭代器将在长度等于最小数据集的长度时耗尽。 但是,使用cycle时,我们将再次重复最小的数据集,除非迭代器查看最大数据集中的所有样本。

pytorch 实现多个Dataloader同时训练

补充:pytorch技巧:自定义数据集 torch.utils.data.DataLoader 及Dataset的使用

本博客中有可直接运行的例子,便于直观的理解,在torch环境中运行即可。

1. 数据传递机制

在 pytorch 中数据传递按一下顺序:

1、创建 datasets ,也就是所需要读取的数据集。

2、把 datasets 传入DataLoader。

3、DataLoader迭代产生训练数据提供给模型。

2. torch.utils.data.Dataset

Pytorch提供两种数据集:

Map式数据集 Iterable式数据集。其中Map式数据集继承torch.utils.data.Dataset,Iterable式数据集继承torch.utils.data.IterableDataset。

本文只介绍 Map式数据集。

一个Map式的数据集必须要重写 __getitem__(self, index)、 __len__(self) 两个方法,用来表示从索引到样本的映射(Map)。 __getitem__(self, index)按索引映射到对应的数据, __len__(self)则会返回这个数据集的长度。

基本格式如下:

import torch.utils.data as data
class VOCDetection(data.Dataset):
    '''
    必须继承data.Dataset类
    '''
    def __init__(self):
        '''
        在这里进行初始化,一般是初始化文件路径或文件列表
        '''
        pass
    def __getitem__(self, index):
        '''
        1. 按照index,读取文件中对应的数据  (读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
        2. 对读取到的数据进行数据增强 (数据增强是深度学习中经常用到的,可以提高模型的泛化能力)
        3. 返回数据对 (一般我们要返回 图片,对应的标签) 在这里因为我没有写完整的代码,返回值用 0 代替
        '''
        return 0
    def __len__(self):
        '''
        返回数据集的长度
        '''
        return 0

可直接运行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
y = np.array(range(8))  # 模拟对应样本的标签, 8个标签 
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index] #可继续进行数据增强,这里没有进行数据增强操作
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
datasets = Mydataset(x, y)  # 初始化
print(datasets.__len__())  # 调用__len__() 返回数据的长度
for i in range(len(y)):
    input_data, target = datasets.__getitem__(i)  # 调用__getitem__(index) 返回读取的数据对
    print('input_data%d =' % i, input_data)
    print('target%d = ' % i, target)

结果如下:

pytorch 实现多个Dataloader同时训练

3. torch.utils.data.DataLoader

PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。

torch.utils.data.DataLoader(onject)的可用参数如下:

1.dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。

2.batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)

3.shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)

4.sampler (Sampler, optional):从数据集中提取样本的策略。如果指定,“shuffle”必须为false。我没有用过,不太了解。

5.batch_sampler (Sampler, optional):和batch_size、shuffle等参数互斥,一般用默认。

6.num_workers:这个参数必须大于等于0,为0时默认使用主线程读取数据,其他大于0的数表示通过多个进程来读取数据,可以加快数据读取速度,一般设置为2的N次方,且小于batch_size(默认:0)

7.collate_fn (callable, optional): 合并样本清单以形成小批量。用来处理不同情况下的输入dataset的封装。

8.pin_memory (bool, optional):如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存中.

9.drop_last (bool, optional): 如果数据集大小不能被批大小整除,则设置为“true”以除去最后一个未完成的批。如果“false”那么最后一批将更小。(默认:false)

10.timeout(numeric, optional):设置数据读取时间限制,超过这个时间还没读取到数据的话就会报错。(默认:0)

11.worker_init_fn (callable, optional): 每个worker初始化函数(默认:None)

可直接运行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
y = np.array(range(8))  # 模拟对应样本的标签, 8个标签
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index]
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
if __name__ ==('__main__'):
    datasets = Mydataset(x, y)  # 初始化
    dataloader = data.DataLoader(datasets, batch_size=4, num_workers=2) 
    for i, (input_data, target) in enumerate(dataloader):
        print('input_data%d' % i, input_data)
        print('target%d' % i, target)

结果如下:(注意看类别,DataLoader把数据封装为Tensor)

pytorch 实现多个Dataloader同时训练

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现简单的获取图片爬虫功能示例
Jul 12 Python
python简单实例训练(21~30)
Nov 15 Python
pandas 空的dataframe 插入列名的示例
Oct 30 Python
利用Python求阴影部分的面积实例代码
Dec 05 Python
Django+Xadmin构建项目的方法步骤
Mar 06 Python
python3在同一行内输入n个数并用列表保存的例子
Jul 20 Python
django 自定义过滤器(filter)处理较为复杂的变量方法
Aug 12 Python
Win系统PyQt5安装和使用教程
Dec 25 Python
基于Tensorflow高阶读写教程
Feb 10 Python
pycharm设置当前工作目录的操作(working directory)
Feb 14 Python
简单了解python shutil模块原理及使用方法
Apr 28 Python
JAVA SpringMVC实现自定义拦截器
Mar 16 Python
python 如何做一个识别率百分百的OCR
基于PyTorch实现一个简单的CNN图像分类器
May 29 #Python
python 爬取华为应用市场评论
python 开心网和豆瓣日记爬取的小爬虫
May 29 #Python
Python趣味挑战之实现简易版音乐播放器
新手必备Python开发环境搭建教程
Keras多线程机制与flask多线程冲突的解决方案
May 28 #Python
You might like
php 正则 过滤html 的超链接
2009/06/02 PHP
PHP 单引号与双引号的区别
2009/11/24 PHP
PHP和Mysqlweb应用开发核心技术 第1部分 Php基础-3 代码组织和重用2
2011/07/03 PHP
PHP 用session与gd库实现简单验证码生成与验证的类方法
2016/11/15 PHP
实例讲解YII2中多表关联的使用方法
2017/07/21 PHP
非常漂亮的JS代码经典广告
2007/10/21 Javascript
javascript限制文本框只允许输入数字(曾经与现在的方法对比)
2013/01/18 Javascript
jQuery中:image选择器用法实例
2015/01/03 Javascript
为JS扩展Array.prototype.indexOf引发的问题及解决办法
2015/01/21 Javascript
javascript实现获取浏览器版本、操作系统类型
2015/01/29 Javascript
javascript实现十秒钟后注册按钮可点击的方法
2015/05/13 Javascript
全面解析DOM操作和jQuery实现选项移动操作代码分享
2016/06/07 Javascript
AngularJs bootstrap搭载前台框架——准备工作
2016/09/01 Javascript
js实现截图保存图片功能的代码示例
2017/02/16 Javascript
基于VUE选择上传图片并页面显示(图片可删除)
2017/05/25 Javascript
详解如何优雅地在React项目中使用Redux
2017/12/28 Javascript
200行代码实现blockchain 区块链实例详解
2018/03/14 Javascript
在vue中使用axios实现post方式获取二进制流下载文件(实例代码)
2019/12/16 Javascript
node.js 使用 net 模块模拟 websocket 握手进行数据传递操作示例
2020/02/11 Javascript
[00:02]DOTA2新版本使用PA至宝后暴击展示
2014/11/19 DOTA
[09:37]2018DOTA2国际邀请赛寻真——不懈追梦的Team Serenity
2018/08/13 DOTA
python实现字符串中字符分类及个数统计
2018/09/28 Python
浅析python中的迭代与迭代对象
2018/10/08 Python
Python datetime包函数简单介绍
2019/08/28 Python
学习Django知识点分享
2019/09/11 Python
python实现斗地主分牌洗牌
2020/06/22 Python
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
欧舒丹英国官网:购买欧舒丹护手霜等明星产品
2017/01/17 全球购物
澳大利亚网上书店:QBD
2021/01/09 全球购物
会议活动邀请函
2014/01/27 职场文书
妇产医师自荐信
2014/01/29 职场文书
小学生暑假感言
2014/02/06 职场文书
幼儿园小班教师个人工作总结
2015/02/06 职场文书
学校党支部公开承诺书
2015/04/30 职场文书
户外拓展训练感想
2015/08/07 职场文书
Python 数据可视化神器Pyecharts绘制图像练习
2022/02/28 Python