Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中的join()方法的使用
May 19 Python
Python线性回归实战分析
Feb 01 Python
Tensorflow分类器项目自定义数据读入的实现
Feb 05 Python
python 缺失值处理的方法(Imputation)
Jul 02 Python
python读写csv文件的方法
Aug 13 Python
在python中利用dict转json按输入顺序输出内容方式
Feb 27 Python
Mysql数据库反向生成Django里面的models指令方式
May 18 Python
对Keras中predict()方法和predict_classes()方法的区别说明
Jun 09 Python
Python select及selectors模块概念用法详解
Jun 22 Python
python3.x中安装web.py步骤方法
Jun 23 Python
用Python制作灯光秀短视频的思路详解
Apr 13 Python
pytorch分类模型绘制混淆矩阵以及可视化详解
Apr 07 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
thinkPHP中create方法与令牌验证实例浅析
2015/12/08 PHP
WordPress中用于更新伪静态规则的PHP代码实例讲解
2015/12/18 PHP
Laravel 模型关联基础教程详解
2019/09/17 PHP
JS实现标签页效果(配合css)
2013/04/03 Javascript
js使用Array.prototype.sort()对数组对象排序的方法
2015/01/28 Javascript
Vue.js 插件开发详解
2017/03/29 Javascript
easyui-edatagrid.js实现回车键结束编辑功能的实例
2017/04/12 Javascript
详解react-router如何实现按需加载
2017/06/15 Javascript
JQuery判断正整数整理小结
2017/08/21 jQuery
Vue.js 2.0和Cordova开发webApp环境搭建方法
2018/02/26 Javascript
利用npm 安装删除模块的方法
2018/05/15 Javascript
vue中使用props传值的方法
2019/05/08 Javascript
js获取浏览器地址(获取第1个斜杠后的内容)
2019/09/03 Javascript
js实现课堂随机点名系统
2019/11/21 Javascript
Vue解决echart在element的tab切换时显示不正确问题
2020/08/03 Javascript
使用python实现递归版汉诺塔示例(汉诺塔递归算法)
2014/04/08 Python
Python 数据结构之队列的实现
2017/01/22 Python
python SSH模块登录,远程机执行shell命令实例解析
2018/01/12 Python
Python3利用Dlib实现摄像头实时人脸检测和平铺显示示例
2019/02/21 Python
python pandas生成时间列表
2019/06/29 Python
django 框架实现的用户注册、登录、退出功能示例
2019/11/28 Python
python模式 工厂模式原理及实例详解
2020/02/11 Python
适合Python初学者的一些编程技巧
2020/02/12 Python
PyInstaller将Python文件打包为exe后如何反编译(破解源码)以及防止反编译
2020/04/15 Python
俄罗斯的精英皮具:Wittchen
2018/01/29 全球购物
美国性感内衣店:Yandy
2018/06/12 全球购物
网上卖盒饭创业计划书
2014/01/26 职场文书
中专生自我鉴定范文
2014/02/02 职场文书
外语专业毕业生自荐信
2014/04/14 职场文书
大学生演讲稿
2014/04/25 职场文书
个人四风问题整改措施思想汇报
2014/10/04 职场文书
基层党组织整改方案
2014/10/25 职场文书
暑假安全保证书
2015/02/28 职场文书
浅谈Web Storage API的使用
2021/06/23 Javascript
Django Paginator分页器的使用示例
2021/06/23 Python
CSS3实现指纹特效代码
2022/03/17 HTML / CSS