Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中unittest用法实例
Sep 25 Python
从Python程序中访问Java类的简单示例
Apr 20 Python
python中argparse模块用法实例详解
Jun 03 Python
Django返回json数据用法示例
Sep 18 Python
Python批量查询域名是否被注册过
Jun 21 Python
Python通过matplotlib画双层饼图及环形图简单示例
Dec 15 Python
浅述python中深浅拷贝原理
Sep 18 Python
使用Python实现在Windows下安装Django
Oct 17 Python
基于python修改srt字幕的时间轴
Feb 03 Python
python利用os模块编写文件复制功能——copy()函数用法
Jul 13 Python
删除pycharm鼠标右键快捷键打开项目的操作
Jan 16 Python
Django中session进行权限管理的使用
Jul 09 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
php获取数组长度的方法(有实例)
2013/10/27 PHP
php使用memcoder将视频转成mp4格式的方法
2015/03/12 PHP
js select常用操作控制代码
2010/03/16 Javascript
High Performance JavaScript(高性能JavaScript)读书笔记分析
2011/05/05 Javascript
基于zepto.js实现仿手机QQ空间的大图查看组件ImageView.js详解
2015/03/05 Javascript
JavaScript中的setUTCDate()方法使用详解
2015/06/11 Javascript
javascript实现跨域的方法汇总
2015/06/25 Javascript
常用javascript表单验证汇总
2020/07/20 Javascript
写给小白的JavaScript引擎指南
2015/12/04 Javascript
Vuejs第十三篇之组件——杂项
2016/09/09 Javascript
Bootstrap时间选择器datetimepicker和daterangepicker使用实例解析
2016/09/17 Javascript
jQuery Validate 校验多个相同name的方法
2017/05/18 jQuery
利用webstrom调试Vue.js单页面程序的方法教程
2017/06/06 Javascript
ReactNative 之FlatList使用及踩坑封装总结
2017/11/29 Javascript
Nuxt配合Node在实际生产中的应用详解
2018/08/07 Javascript
javascript写一个ajax自动拦截并下载数据代码实例
2019/09/07 Javascript
vue计算属性+vue中class与style绑定(推荐)
2020/03/30 Javascript
python创建只读属性对象的方法(ReadOnlyObject)
2013/02/10 Python
进一步探究Python中的正则表达式
2015/04/28 Python
动态规划之矩阵连乘问题Python实现方法
2017/11/27 Python
numpy添加新的维度:newaxis的方法
2018/08/02 Python
python顺序执行多个py文件的方法
2019/06/29 Python
pytorch:torch.mm()和torch.matmul()的使用
2019/12/27 Python
python Plotly绘图工具的简单使用
2020/03/03 Python
Python logging模块异步线程写日志实现过程解析
2020/06/30 Python
世界上最大的曲棍球商店:Pro Hockey Life
2017/10/30 全球购物
医护人员英文求职信范文
2013/11/26 职场文书
银行实习生的自我评价
2013/12/09 职场文书
大学生学业生涯规划
2014/01/05 职场文书
个性发展自我评价
2014/02/11 职场文书
销售主管岗位职责范本
2014/02/14 职场文书
动物科学专业求职信
2014/07/27 职场文书
大学生交通专业求职信
2014/09/01 职场文书
2014年科室工作总结
2014/11/20 职场文书
拾金不昧表扬信
2015/01/16 职场文书
MySQL 执行数据库更新update操作的时候数据库卡死了
2022/05/02 MySQL