pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现超简单端口转发的方法
Mar 13 Python
Python中计算三角函数之cos()方法的使用简介
May 15 Python
python 生成图形验证码的方法示例
Nov 11 Python
Python中extend和append的区别讲解
Jan 24 Python
PyQt5实现五子棋游戏(人机对弈)
Mar 24 Python
Python向excel中写入数据的方法
May 05 Python
如何基于python生成list的所有的子集
Nov 11 Python
Django自带的加密算法及加密模块详解
Dec 03 Python
3分钟看懂Python后端必须知道的Django的信号机制
Jul 26 Python
pycharm 添加解释器的方法步骤
Aug 31 Python
如何解决python多种版本冲突问题
Oct 13 Python
Python实现列表拼接和去重的三种方式
Jul 02 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
php 表单提交大量数据发生丢失的解决方法
2014/03/03 PHP
php调用nginx的mod_zip模块打包ZIP文件
2014/06/11 PHP
PHP输入流php://input实例讲解
2015/12/22 PHP
PHP实现超简单的SSL加密解密、验证及签名的方法示例
2017/08/28 PHP
PHP基于关联数组20行代码搞定约瑟夫问题示例
2017/11/07 PHP
php快速导入大量数据的实例方法
2019/09/23 PHP
JavaScript操作XML实例代码(获取新闻标题并分页,并分页)
2010/05/25 Javascript
25个好玩的JavaScript小游戏分享
2011/04/22 Javascript
js 定时器setTimeout无法调用局部变量的解决办法
2013/11/28 Javascript
table对象中的insertRow与deleteRow使用示例
2014/01/26 Javascript
JS建造者模式基本用法实例分析
2015/06/30 Javascript
jQuery图片左右滚动代码 有左右按钮实例
2016/06/20 Javascript
JS实现JSON.stringify的实例代码讲解
2017/02/07 Javascript
jQuery鼠标悬停内容动画切换效果
2017/04/27 jQuery
原生JS+Canvas实现五子棋游戏实例
2017/06/19 Javascript
详解EasyUi控件中的Datagrid
2017/08/23 Javascript
JavaScript设计模式之缓存代理模式原理与简单用法示例
2018/08/07 Javascript
关于微信公众号开发无法支付的问题解决
2018/12/28 Javascript
[43:32]2014 DOTA2华西杯精英邀请赛 5 25 LGD VS NewBee第一场
2014/05/26 DOTA
Python中函数的多种格式和使用实例及小技巧
2015/04/13 Python
在Python中操作字典之setdefault()方法的使用
2015/05/21 Python
深入解析Python中的lambda表达式的用法
2015/08/28 Python
Python判断变量是否为Json格式的字符串示例
2017/05/03 Python
解决Pycharm界面的子窗口不见了的问题
2019/01/17 Python
Python3.5文件修改操作实例分析
2019/05/01 Python
python实现串口自动触发工作的示例
2019/07/02 Python
Pytorch Tensor的统计属性实例讲解
2019/12/30 Python
CSS3 伪类选择器 nth-child()说明
2010/07/10 HTML / CSS
中国旅游网站:途牛旅游网
2019/09/29 全球购物
俄罗斯品牌服装和鞋子在线商店:BRIONITY
2020/03/26 全球购物
计算机学生求职信范文
2014/01/30 职场文书
公司门卫岗位职责
2014/03/15 职场文书
任命书范本大全
2014/06/06 职场文书
市委常委班子党的群众路线教育实践活动整改措施
2014/10/02 职场文书
导游词之山东八大关
2019/12/18 职场文书
python数据分析之单因素分析线性拟合及地理编码
2022/06/25 Python