Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
如何用Python制作微信好友个性签名词云图
Jun 28 Python
Pandas聚合运算和分组运算的实现示例
Oct 17 Python
简单了解Python3 bytes和str类型的区别和联系
Dec 19 Python
Python实现子类调用父类的初始化实例
Mar 12 Python
解决jupyter notebook import error但是命令提示符import正常的问题
Apr 15 Python
Python坐标轴操作及设置代码实例
Jun 04 Python
Python基于httpx模块实现发送请求
Jul 07 Python
pandas apply多线程实现代码
Aug 17 Python
Python unittest生成测试报告过程解析
Sep 08 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
Jan 26 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
如何在Python中创建二叉树
Mar 30 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
PHP通过session id 实现session共享和登录验证的代码
2012/06/03 PHP
yum命令安装php7和相关扩展
2016/07/04 PHP
Laravel学习基础之migrate的使用教程
2017/10/11 PHP
php解压缩zip和rar压缩包文件的方法
2019/07/10 PHP
javascript 读取图片文件的大小
2009/06/25 Javascript
javascript 打印内容方法小结
2009/11/04 Javascript
Ext grid 添加右击菜单
2009/11/26 Javascript
基于jQuery的固定表格头部的代码(IE6,7,8测试通过)
2010/05/18 Javascript
Jquery取得iframe下内容的方法
2013/11/18 Javascript
详解jQuery的表单验证插件--Validation
2016/12/21 Javascript
JS字符串按逗号和回车分隔的方法
2017/04/25 Javascript
Node.js实现发送邮件功能
2017/11/06 Javascript
vue实现图片滚动的示例代码(类似走马灯效果)
2018/03/03 Javascript
js中Object.defineProperty()方法的不详解
2018/07/09 Javascript
laydate如何根据开始时间或者结束时间限制范围
2018/11/15 Javascript
使用vue-router切换页面时,获取上一页url以及当前页面url的方法
2019/05/06 Javascript
用Vue.js在浏览器中实现裁剪图像功能
2019/06/18 Javascript
js模拟F11页面全屏显示
2019/09/17 Javascript
vue实现跳转接口push 转场动画示例
2019/11/01 Javascript
JavaScript实现Excel表格效果
2020/02/07 Javascript
利用soaplib搭建webservice详细步骤和实例代码
2013/11/20 Python
全面了解Python环境配置及项目建立
2016/06/30 Python
Python字符串逆序输出的实例讲解
2019/02/16 Python
Django框架封装外部函数示例
2019/05/28 Python
python实现列表中最大最小值输出的示例
2019/07/09 Python
django rest framework使用django-filter用法
2020/07/15 Python
用python实现学生管理系统
2020/07/24 Python
澳大利亚现代波西米亚风格女装网站:Bohemian Traders
2018/04/16 全球购物
介绍一下Java中标识符的命名规则
2014/02/03 面试题
房地产员工找工作的自我评价
2013/11/15 职场文书
建筑行业的大学生自我评价
2013/12/08 职场文书
毕业生实习期转正自我鉴定
2014/09/26 职场文书
岳庙导游词
2015/02/04 职场文书
求职自荐信怎么写
2015/03/04 职场文书
再也不用花钱买漫画!Python爬取某漫画的脚本及源码
2021/06/09 Python
SpringBoot项目部署到阿里云服务器的实现步骤
2022/06/28 Java/Android