Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
vc6编写python扩展的方法分享
Jan 17 Python
Python魔术方法详解
Feb 14 Python
python实现的用于搜索文件并进行内容替换的类实例
Jun 28 Python
数组保存为txt, npy, csv 文件, 数组遍历enumerate的方法
Jul 09 Python
Python操作word常见方法示例【win32com与docx模块】
Jul 17 Python
python版大富翁源代码分享
Nov 19 Python
利用python、tensorflow、opencv、pyqt5实现人脸实时签到系统
Sep 25 Python
python实现的按要求生成手机号功能示例
Oct 08 Python
python框架flask入门之环境搭建及开启调试
Jun 07 Python
Win10下用Anaconda安装TensorFlow(图文教程)
Jun 18 Python
QT5 Designer 打不开的问题及解决方法
Aug 20 Python
Python实现FTP文件定时自动下载的步骤
Dec 19 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
浅谈电磁辐射对健康的影响
2021/03/01 无线电
php下过滤html代码的函数 提高程序安全性
2010/03/02 PHP
Yii配置文件用法详解
2014/12/04 PHP
PHP 中TP5 Request 请求对象的实例详解
2017/07/31 PHP
javascript tips提示框组件实现代码
2010/11/19 Javascript
JavaScript入门之对象与JSON详解
2011/10/21 Javascript
JavaScript中点击事件的写法
2016/06/28 Javascript
jquery datatable服务端分页
2016/08/31 Javascript
用jQuery.ajaxSetup实现对请求和响应数据的过滤
2016/12/20 Javascript
bootstrap选项卡使用方法解析
2017/01/11 Javascript
原生js轮播(仿慕课网)
2017/02/15 Javascript
Node.js服务器开启Gzip压缩教程
2017/08/11 Javascript
react项目实践之webpack-dev-serve
2018/09/14 Javascript
CountUp.js数字滚动插件使用方法详解
2019/10/17 Javascript
JavaScript实现公告栏上下滚动效果
2020/03/13 Javascript
vue实现导航菜单和编辑文本的示例代码
2020/07/04 Javascript
打开电脑上的QQ的python代码
2013/02/10 Python
用Python的线程来解决生产者消费问题的示例
2015/04/02 Python
在Python中使用判断语句和循环的教程
2015/04/25 Python
浅析python递归函数和河内塔问题
2017/04/18 Python
Python Flask前后端Ajax交互的方法示例
2018/07/31 Python
关于numpy中eye和identity的区别详解
2019/11/29 Python
python+opencv3.4.0 实现HOG+SVM行人检测的示例代码
2021/01/28 Python
利用CSS3实现平移动画效果示例代码
2016/10/12 HTML / CSS
Html5上传图片 移动端、PC端通用代码
2016/06/08 HTML / CSS
瑞士网球商店:Tennis-Point
2020/03/12 全球购物
新锐科技Java程序员面试题
2016/07/25 面试题
日语专业毕业生自荐信
2013/11/11 职场文书
省三好学生申请材料
2014/01/22 职场文书
幼儿教师国培感言
2014/02/19 职场文书
农业开发项目建议书
2014/05/16 职场文书
党的群众路线教育实践活动查摆问题自查报告
2014/10/10 职场文书
惊涛骇浪观后感
2015/06/05 职场文书
2015秋季开学典礼新闻稿
2015/07/17 职场文书
运动会跳远广播稿
2015/08/19 职场文书
高一语文教学反思
2016/02/16 职场文书