Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python从零实现贝叶斯分类器的机器学习的教程
Mar 31 Python
Python中的lstrip()方法使用简介
May 19 Python
Python聊天室实例程序分享
Jan 05 Python
Python图像处理之图像的读取、显示与保存操作【测试可用】
Jan 04 Python
使用python实现语音文件的特征提取方法
Jan 09 Python
Python之lambda匿名函数及map和filter的用法
Mar 05 Python
Flask框架中request、请求钩子、上下文用法分析
Jul 23 Python
python写程序统计词频的方法
Jul 29 Python
在Python中画图(基于Jupyter notebook的魔法函数)
Oct 28 Python
关于python中plt.hist参数的使用详解
Nov 28 Python
Python模块的制作方法实例分析
Dec 21 Python
Python从文件中读取数据的方法步骤
Nov 18 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
第十三节 对象串行化 [13]
2006/10/09 PHP
ThinkPHP3.1数据CURD操作快速入门
2014/06/19 PHP
PHP编程中的常见漏洞和代码实例
2014/08/06 PHP
WordPress中缩略图的使用以及相关技巧
2015/11/24 PHP
laravel5.0在linux下解决.htaccess无效和去除index.php的问题
2019/10/16 PHP
javascript SpiderMonkey中的函数序列化如何进行
2012/12/05 Javascript
JavaScript设计模式之装饰者模式介绍
2014/12/28 Javascript
微信小程序之多文件下载的简单封装示例
2018/01/29 Javascript
jQuery插件jsonview展示json数据
2018/05/26 jQuery
JS中的模糊查询功能
2019/12/08 Javascript
Python中基本的日期时间处理的学习教程
2015/10/16 Python
利用matplotlib+numpy绘制多种绘图的方法实例
2017/05/03 Python
Python3 socket同步通信简单示例
2017/06/07 Python
python生成excel的实例代码
2017/11/08 Python
python中的闭包函数
2018/02/09 Python
python Spyder界面无法打开的解决方法
2018/04/27 Python
Selenium控制浏览器常见操作示例
2018/08/13 Python
python3调用百度翻译API实现实时翻译
2018/08/16 Python
解决jupyter运行pyqt代码内核重启的问题
2020/04/16 Python
python对接ihuyi实现短信验证码发送
2020/05/10 Python
Python如何爬取qq音乐歌词到本地
2020/06/01 Python
通过代码实例了解Python sys模块
2020/09/14 Python
Python中的特殊方法以及应用详解
2020/09/20 Python
matplotlib 使用 plt.savefig() 输出图片去除旁边的空白区域
2021/01/05 Python
英国最大的专业户外零售商:Mountain Warehouse
2018/06/06 全球购物
皇家阿尔伯特瓷器美国官网:Royal Albert美国
2020/02/16 全球购物
建筑班组长岗位职责
2014/01/02 职场文书
十佳教师事迹材料
2014/01/11 职场文书
优秀护士演讲稿
2014/04/30 职场文书
机械加工与数控专业自荐书
2014/06/04 职场文书
改进作风怎么办发言材料
2014/08/17 职场文书
预备党员转正思想汇报
2014/09/26 职场文书
清明节文明祭祀倡议书
2015/04/28 职场文书
幼儿园园务工作总结2015
2015/05/18 职场文书
Python游戏开发实例之graphics实现AI五子棋
2021/11/01 Python
Win11右下角图标点了没反应怎么办?Win11点击右下角图标无反应解决方法汇总
2022/07/07 数码科技