pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析文件示例
Jan 23 Python
利用Pandas 创建空的DataFrame方法
Apr 08 Python
Python基于TCP实现会聊天的小机器人功能示例
Apr 09 Python
python绘制直方图和密度图的实例
Jul 08 Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 Python
python global和nonlocal用法解析
Feb 03 Python
python如何实现单链表的反转
Feb 10 Python
Python将二维列表list的数据输出(TXT,Excel)
Apr 23 Python
python 负数取模运算实例
Jun 03 Python
详解分布式系统中如何用python实现Paxos
May 18 Python
基于Python实现将列表数据生成折线图
Mar 23 Python
Python+pyaudio实现音频控制示例详解
Jul 23 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP中is_file不能替代file_exists的理由
2014/03/04 PHP
PHP使用header()输出图片缓存实例
2014/12/09 PHP
图片按比例缩放函数
2006/06/26 Javascript
解决 firefox 不支持 document.all的方法
2007/03/12 Javascript
ExtJS 学习专题(一) 如何应用ExtJS(附实例)
2010/03/11 Javascript
javascript 用记忆函数快速计算递归函数
2010/03/15 Javascript
游览器中javascript的执行过程(图文)
2012/05/20 Javascript
jQuery中insertAfter()方法用法实例
2015/01/08 Javascript
Javascript封装id、class与元素选择器方法示例
2017/03/13 Javascript
jQuery实现简单日期格式化功能示例
2017/09/19 jQuery
JS非行间样式获取函数的实例代码
2018/06/05 Javascript
使用express来代理服务的方法
2019/06/21 Javascript
vue中使用element ui的弹窗与echarts之间的问题详解
2019/10/25 Javascript
JS实现滑动拼图验证功能完整示例
2020/03/29 Javascript
js this 绑定机制深入详解
2020/04/30 Javascript
如何在node环境实现“get数据解析”代码实例
2020/07/03 Javascript
UEditor 自定义图片视频尺寸校验功能的实现代码
2020/10/20 Javascript
[01:55]2014DOTA2国际邀请赛快报:国土生病 紧急去医院治疗
2014/07/10 DOTA
[02:17]快乐加倍!DOTA2食人魔魔法师至宝+迎霜节活动上线
2019/12/22 DOTA
Python 搭建Web站点之Web服务器与Web框架
2016/11/06 Python
Python装饰器用法实例总结
2018/02/07 Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
2018/04/28 Python
Flask核心机制之上下文源码剖析
2018/12/25 Python
Django使用redis缓存服务器的实现代码示例
2019/04/28 Python
Python 继承,重写,super()调用父类方法操作示例
2019/09/29 Python
使用python从三个角度解决josephus问题的方法
2020/03/27 Python
Python存储读取HDF5文件代码解析
2020/11/25 Python
c/c++某大公司的两道笔试题
2014/02/02 面试题
工作中的自我评价如何写好
2013/10/28 职场文书
音乐教育感言
2014/03/05 职场文书
校园学雷锋活动月总结
2014/03/09 职场文书
《黄山奇石》教学反思
2014/04/19 职场文书
高中生学习计划书
2014/09/15 职场文书
2014年乡镇工会工作总结
2014/12/02 职场文书
2015年国税春训心得体会
2015/03/09 职场文书
2015年小学图书室工作总结
2015/05/18 职场文书