Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python thread 并发且顺序运行示例
Apr 09 Python
Python中__init__和__new__的区别详解
Jul 09 Python
Python实现把utf-8格式的文件转换成gbk格式的文件
Jan 22 Python
浅谈Python 对象内存占用
Jul 15 Python
Python中用post、get方式提交数据的方法示例
Sep 22 Python
python发送邮件脚本
May 22 Python
pyqt5 实现在别的窗口弹出进度条
Jun 18 Python
python获取当前文件路径以及父文件路径的方法
Jul 10 Python
Python多线程获取返回值代码实例
Feb 17 Python
解决Python图形界面中设置尺寸的问题
Mar 05 Python
python+requests实现接口测试的完整步骤
Oct 27 Python
Python天气语音播报小助手
Sep 25 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
不用GD库生成当前时间的PNG格式图象的程序
2006/10/09 PHP
PHP函数eval()介绍和使用示例
2014/08/20 PHP
Thinkphp批量更新数据的方法汇总
2016/06/29 PHP
ThinkPHP实现转换数据库查询结果数据到对应类型的方法
2017/11/16 PHP
PHP获取当前时间不准确问题解决方案
2020/08/14 PHP
javascript中万恶的function实例分析
2011/05/25 Javascript
文本框获得焦点和失去焦点的判断代码
2012/03/18 Javascript
固定网页背景图同时保持图片比例的思路代码
2013/08/15 Javascript
JS正则表达式大全(整理详细且实用)
2013/11/14 Javascript
js加载读取内容及显示与隐藏div示例
2014/02/13 Javascript
使用js实现的简单拖拽效果
2015/03/18 Javascript
JavaScript操作XML文件之XML读取方法
2015/06/09 Javascript
JS仿hao123导航页面图片轮播效果
2016/09/01 Javascript
jQuery替换节点用法示例(使用replaceWith方法)
2016/09/08 Javascript
基于jQuery解决ios10以上版本缩放问题
2017/11/03 jQuery
vue-自定义组件传值的实例讲解
2018/09/18 Javascript
Vue监听一个数组id是否与另一个数组id相同的方法
2018/09/26 Javascript
Vue路由守卫之路由独享守卫
2019/09/25 Javascript
ES5 模拟 ES6 的 Symbol 实现私有成员功能示例
2020/05/06 Javascript
vue3为什么要用proxy替代defineProperty
2020/10/19 Javascript
[03:52]DOTA2英雄基础教程 酒仙
2013/12/23 DOTA
python封装对象实现时间效果
2020/04/23 Python
python使用递归解决全排列数字示例
2014/02/11 Python
Python实现朴素贝叶斯分类器的方法详解
2018/07/04 Python
pygame游戏之旅 添加键盘按键的方法
2018/11/20 Python
使用python制作一个为hex文件增加版本号的脚本实例
2019/06/12 Python
python被修饰的函数消失问题解决(基于wraps函数)
2019/11/04 Python
详解pycharm的python包opencv(cv2)无代码提示问题的解决
2021/01/29 Python
DVF官方网站:美国时装界尊尚品牌
2017/08/29 全球购物
Bath & Body Works阿联酋:在线购买沐浴和身体用品
2021/02/27 全球购物
大学生求职简历的自我评价范文
2013/10/12 职场文书
开办大学饮食联盟创业计划书
2014/01/29 职场文书
大学应届生的自我评价
2014/03/06 职场文书
厨师长岗位职责范本
2014/08/25 职场文书
Python实战之实现简易的学生选课系统
2021/05/25 Python
python周期任务调度工具Schedule使用详解
2021/11/23 Python