Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyqt和pyside开发图形化界面
Jan 22 Python
Python获取当前公网ip并自动断开宽带连接实例代码
Jan 12 Python
Python 使用with上下文实现计时功能
Mar 09 Python
对Python中内置异常层次结构详解
Oct 18 Python
Python爬取破解无线网络wifi密码过程解析
Sep 17 Python
python next()和iter()函数原理解析
Feb 07 Python
python GUI库图形界面开发之PyQt5 MDI(多文档窗口)QMidArea详细使用方法与实例
Mar 05 Python
Python3.7 读取音频根据文件名生成脚本的代码
Apr 07 Python
python实现密码验证合格程序的思路详解
Jun 01 Python
python的launcher用法知识点总结
Aug 07 Python
python 操作excel表格的方法
Dec 05 Python
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
WINXP下apache+php4+mysql
2006/11/25 PHP
php addslashes及其他清除空格的方法是不安全的
2012/01/25 PHP
与文件上传有关的php配置参数总结
2013/06/14 PHP
PHP实现XML与数据格式进行转换类实例
2015/07/29 PHP
PHP提高编程效率的20个要点
2015/09/23 PHP
PHP新建类问题分析及解决思路
2015/11/19 PHP
PHP中类型转换 ,常量,系统常量,魔术常量的详解
2017/10/26 PHP
Laravel使用消息队列需要注意的一些问题
2017/12/13 PHP
Laravel数据库读写分离配置的方法
2019/10/13 PHP
jquery实现在页面加载的时自动为日期插件添加当前日期
2014/08/20 Javascript
JavaScript实现的石头剪刀布游戏源码分享
2014/08/22 Javascript
jQuery判断元素是否显示 是否隐藏的简单实现代码
2016/05/19 Javascript
微信小程序实现多个按钮toggle功能的实例
2017/06/13 Javascript
基于jQuery封装的分页组件
2017/06/26 jQuery
微信小程序之判断页面滚动方向的示例代码
2018/08/30 Javascript
基于Koa2写个脚手架模拟接口服务的方法
2018/11/27 Javascript
Node.js一行代码实现静态文件服务器的方法步骤
2019/05/07 Javascript
webpack中如何加载静态文件的方法步骤
2019/05/18 Javascript
Bootstrap简单实用的表单验证插件BootstrapValidator用法实例详解
2020/03/29 Javascript
基于redis的小程序登录实现方法流程分析
2020/05/25 Javascript
[02:28]DOTA2英雄基础教程 灰烬之灵
2013/12/19 DOTA
[00:09]DOTA2全国高校联赛 精彩活动引爆全场
2018/05/30 DOTA
python3使用urllib模块制作网络爬虫
2016/04/08 Python
为什么选择python编程语言入门黑客攻防 给你几个理由!
2018/02/02 Python
Python Numpy 数组的初始化和基本操作
2018/03/13 Python
python tkinter界面居中显示的方法
2018/10/11 Python
python设置环境变量的作用和实例
2019/07/09 Python
浅谈keras保存模型中的save()和save_weights()区别
2020/05/21 Python
Python实现疫情地图可视化
2021/02/05 Python
销售类个人求职信范文
2013/09/25 职场文书
团日活动总结报告
2014/06/25 职场文书
单位领导婚礼致辞
2015/07/28 职场文书
教务处干事工作总结
2015/08/14 职场文书
暑假开始了,你的暑假学习计划写好了吗?
2019/07/04 职场文书
导游词之阆中古城
2019/12/23 职场文书
解决vue中provide inject的响应式监听
2022/04/19 Vue.js