Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python浅拷贝与深拷贝用法实例
May 09 Python
Python爬虫框架Scrapy实战之批量抓取招聘信息
Aug 07 Python
Python中eval带来的潜在风险代码分析
Dec 11 Python
Python字符串的一些操作方法总结
Jun 10 Python
python命令行参数用法实例分析
Jun 25 Python
Python测试模块doctest使用解析
Aug 10 Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 Python
pandas read_excel()和to_excel()函数解析
Sep 19 Python
Django文件上传与下载(FileFlid)
Oct 06 Python
PYTHON绘制雷达图代码实例
Oct 15 Python
使用Python+selenium实现第一个自动化测试脚本
Mar 17 Python
Python 操作 PostgreSQL 数据库示例【连接、增删改查等】
Apr 21 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
destoon会员注册提示“数据校验失败(2)”解决方法
2014/06/21 PHP
php自定义错误处理用法实例
2015/03/20 PHP
PHP开发中AJAX技术的简单应用
2015/12/11 PHP
Thinkphp5.0自动生成模块及目录的方法详解
2017/04/17 PHP
微信开发之获取JSAPI TICKET
2017/07/07 PHP
FireFox与IE 下js兼容触发click事件的代码
2008/11/20 Javascript
jQuery 位置函数offset,innerWidth,innerHeight,outerWidth,outerHeight,scrollTop,scrollLeft
2010/03/23 Javascript
js 内存释放问题
2010/04/25 Javascript
Jquery中Ajax 缓存带来的影响的解决方法
2011/05/19 Javascript
Js+Flash实现访问剪切板操作
2012/11/20 Javascript
Extjs显示从数据库取出时间转换JSON后的出现问题
2012/11/20 Javascript
JavaScript创建类/对象的几种方式概述及实例
2013/05/06 Javascript
js 控制图片大小核心讲解
2013/10/09 Javascript
JavaScript初学者建议:不要去管浏览器兼容
2014/02/04 Javascript
javascript异步编程的4种方法
2014/02/19 Javascript
JQuery中使用ajax传输超大数据的解决方法
2014/07/14 Javascript
javascript trim函数在IE下不能用的解决方法
2014/09/12 Javascript
JS图片压缩(pc端和移动端都适用)
2017/01/12 Javascript
利用Javascript获取选择文本所在的句子详解
2017/12/03 Javascript
性能优化篇之Webpack构建速度优化的建议
2019/04/03 Javascript
微信小程序获取公众号文章列表及显示文章的示例代码
2020/03/10 Javascript
浅谈vue项目利用Hbuilder打包成APP流程,以及遇到的坑
2020/09/12 Javascript
[01:09:50]VP vs Pain 2018国际邀请赛小组赛BO2 第二场
2018/08/20 DOTA
Python下rrdtool模块的基本使用方法
2015/11/13 Python
python 将list转成字符串,中间用符号分隔的方法
2018/10/23 Python
python+ffmpeg批量去视频开头的方法
2019/01/09 Python
使用Django简单编写一个XSS平台的方法步骤
2019/03/25 Python
python操作小程序云数据库实现简单的增删改查功能
2019/06/06 Python
python画图把时间作为横坐标的方法
2019/07/07 Python
浅谈Pytorch中的torch.gather函数的含义
2019/08/18 Python
python用类实现文章敏感词的过滤方法示例
2019/10/27 Python
Python 实现Serial 与STM32J进行串口通讯
2019/12/18 Python
世界上最值得信赖的多日游在线市场:TourRadar
2018/07/20 全球购物
医学生个人求职信范文
2013/09/24 职场文书
好好学习保证书
2015/02/26 职场文书
劳动仲裁代理词范文
2015/05/25 职场文书