pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解python中的闭包和装饰器
Jun 12 Python
Python下调用Linux的Shell命令的方法
Jun 12 Python
python中int与str互转方法
Jul 02 Python
Python编程深度学习计算库之numpy
Dec 28 Python
Python可迭代对象操作示例
May 07 Python
selenium处理元素定位点击无效问题
Jun 12 Python
使用turtle绘制五角星、分形树
Oct 06 Python
python sorted方法和列表使用解析
Nov 18 Python
Python面向对象封装操作案例详解 II
Jan 02 Python
python实现猜数游戏
Mar 27 Python
解决pytorch读取自制数据集出现过的问题
May 31 Python
Python中的pprint模块
Nov 27 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
利用static实现表格的颜色隔行显示
2006/10/09 PHP
PHP Google的translate API代码
2008/12/10 PHP
Fatal error: Call to undefined function curl_init()解决方法
2010/04/09 PHP
php中使用explode查找某个字符是否存在的方法
2011/07/12 PHP
yii2中的rules 自定义验证规则详解
2016/04/19 PHP
PHP Filter过滤器全面解析
2016/08/09 PHP
php 输入输出流详解及示例代码
2016/08/25 PHP
通过正则表达式实现表单验证是否为中文
2014/02/18 Javascript
javascript实现表单提交后,提交按钮不可用的方法
2015/04/18 Javascript
JavaScript中的对象继承关系
2016/08/01 Javascript
JavaScript原型继承_动力节点Java学院整理
2017/06/30 Javascript
bootstrap table sum总数量统计实现方法
2017/10/29 Javascript
vue-router 组件复用问题详解
2018/01/22 Javascript
vue实现裁切图片同时实现放大、缩小、旋转功能
2018/03/02 Javascript
Bootstrap Table实现定时刷新数据的方法
2018/08/13 Javascript
vue 更改连接后台的api示例
2019/11/11 Javascript
解决vue单页面多个组件嵌套监听浏览器窗口变化问题
2020/07/30 Javascript
Vue中computed和watch有哪些区别
2020/12/19 Vue.js
vue-cli 3如何使用vue-bootstrap-datetimepicker日期插件
2021/02/20 Vue.js
[53:15]Newbee vs Pain 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
[38:31]完美世界DOTA2联赛PWL S3 Magma vs GXR 第一场 12.13
2020/12/17 DOTA
用Python制作简单的朴素基数估计器的教程
2015/04/01 Python
Python实现的数据结构与算法之队列详解
2015/04/22 Python
Python get获取页面cookie代码实例
2018/09/12 Python
Django 日志配置按日期滚动的方法
2019/01/31 Python
Windows平台Python编程必会模块之pywin32介绍
2019/10/01 Python
python3处理word文档实例分析
2020/12/01 Python
如何用canvas实现在线签名的示例代码
2018/07/10 HTML / CSS
Desigual英国官网:在线购买原创服装
2018/03/09 全球购物
介绍一下Java中标识符的命名规则
2014/02/03 面试题
新学期开学寄语
2014/01/18 职场文书
《小猫刮胡子》教学反思
2014/02/21 职场文书
运动会跳远广播稿5篇
2014/09/17 职场文书
英语辞职信范文
2015/02/28 职场文书
升学宴学生致辞
2015/07/27 职场文书
Python面向对象编程之类的概念
2021/11/01 Python