Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本来控制Windows Azure的简单教程
Apr 16 Python
Python中atexit模块的基本使用示例
Jul 08 Python
Python PyQt5标准对话框用法示例
Aug 23 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
Sep 25 Python
python3实现公众号每日定时发送日报和图片
Feb 24 Python
Python+OpenCV实现图像融合的原理及代码
Dec 03 Python
Python学习笔记基本数据结构之序列类型list tuple range用法分析
Jun 08 Python
python基于K-means聚类算法的图像分割
Oct 30 Python
Django User 模块之 AbstractUser 扩展详解
Mar 11 Python
python爬虫基础知识点整理
Jun 02 Python
PyCharm中关于安装第三方包的三个建议
Sep 17 Python
python中zip()函数遍历多个列表方法
Feb 18 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
评分9.0以上的动画电影,剧情除了经典还很燃
2020/03/04 日漫
PHP与javascript的两种交互方式
2006/10/09 PHP
php中time()与$_SERVER[REQUEST_TIME]用法区别
2014/11/19 PHP
Yii中的relations数据关联查询及统计功能用法详解
2016/07/14 PHP
总结PHP中数值计算的注意事项
2016/08/14 PHP
发布一个高效的JavaScript分析、压缩工具 JavaScript Analyser
2007/11/30 Javascript
javascript web对话框与弹出窗口
2009/02/22 Javascript
JavaScript中的prototype使用说明
2010/04/13 Javascript
jQuery点击tr实现checkbox选中的方法
2013/03/19 Javascript
js计算两个时间之间天数差的实例代码
2013/11/19 Javascript
JavaScript中使用arguments获得函数传参个数实例
2014/08/27 Javascript
Jquery为DIV添加click事件的简单实例
2016/06/02 Javascript
jQuery实现拖拽页面元素并将其保存到cookie的方法
2016/06/12 Javascript
Angular外部使用js调用Angular控制器中的函数方法或变量用法示例
2016/08/05 Javascript
JS实现页面载入时随机显示图片效果
2016/09/07 Javascript
简单实现AngularJS轮播图效果
2020/04/10 Javascript
js装饰设计模式学习心得
2018/02/17 Javascript
JavaScript基于对象方法实现数组去重及排序操作示例
2018/07/10 Javascript
jquery实现动态添加附件功能
2018/10/23 jQuery
layer的prompt弹出框,点击回车,触发确定事件的方法
2019/09/06 Javascript
vue-froala-wysiwyg 富文本编辑器功能
2019/09/19 Javascript
React 实现车牌键盘的示例代码
2019/12/20 Javascript
[41:12]Liquid vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.24
2019/09/10 DOTA
Python的加密模块md5、sha、crypt使用实例
2014/09/28 Python
Python实现网站注册验证码生成类
2017/06/08 Python
python实现微信远程控制电脑
2018/02/22 Python
python3+PyQt5实现自定义分数滑块部件
2018/04/24 Python
PHP实现发送和接收JSON请求
2018/06/07 Python
django模型动态修改参数,增加 filter 字段的方式
2020/03/16 Python
CSS3之背景尺寸Background-size使用介绍
2013/10/14 HTML / CSS
岗位职责的含义
2013/11/17 职场文书
硕士生工作推荐信
2014/03/07 职场文书
导游词之沈阳清昭陵
2019/12/28 职场文书
详解如何使用Node.js实现热重载页面
2021/05/06 Javascript
Python OpenCV实现传统图片格式与base64转换
2021/06/13 Python
Python 居然可以在 Excel 中画画你知道吗
2022/02/15 Python