pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的框架中为MySQL实现restful接口的教程
Apr 08 Python
python实现从字典中删除元素的方法
May 04 Python
python使用xmlrpclib模块实现对百度google的ping功能
Jun 02 Python
Python的Flask框架中的Jinja2模板引擎学习教程
Jun 30 Python
python和flask中返回JSON数据的方法
Mar 26 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
May 25 Python
浅析python中的迭代与迭代对象
Oct 08 Python
Python2和Python3中urllib库中urlencode的使用注意事项
Nov 26 Python
python 检查是否为中文字符串的方法
Dec 28 Python
Django中的FBV和CBV用法详解
Sep 15 Python
python如何查看安装了的模块
Jun 23 Python
python快速安装OpenCV的步骤记录
Feb 22 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
Cannot modify header information错误解决方法
2008/10/08 PHP
php下载excel无法打开的解决方法
2013/12/24 PHP
phpword插件导出word文件时中文乱码问题处理方案
2014/08/19 PHP
php检测url是否存在的方法
2015/04/14 PHP
常见的四种POST 提交数据方式(小总结)
2015/10/08 PHP
PHP利用百度ai实现文本和图片审核
2019/05/08 PHP
IE iframe的onload方法分析小结
2010/01/07 Javascript
javascript天然的迭代器
2010/10/29 Javascript
浅析Js(Jquery)中,字符串与JSON格式互相转换的示例(直接运行实例)
2013/07/09 Javascript
js控制不同的时间段显示不同的css样式的实例代码
2013/11/04 Javascript
jquery实现背景墙聚光灯效果示例分享
2014/03/02 Javascript
javascript实现延时显示提示框特效代码
2016/04/27 Javascript
字符串反转_JavaScript
2016/04/28 Javascript
初探nodeJS
2017/01/24 NodeJs
判断横屏竖屏(三种)
2017/02/13 Javascript
详解使用angular框架离线你的应用(pwa指南)
2019/01/31 Javascript
layui使用templet格式化表格数据的方法
2019/09/16 Javascript
JQuery中DOM节点的操作与访问方法实例分析
2019/12/23 jQuery
JS面向对象之多选框实现
2020/01/17 Javascript
区分vue-router的hash和history模式
2020/10/03 Javascript
[08:07]DOTA2每周TOP10 精彩击杀集锦vol.8
2014/06/25 DOTA
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
2014/04/25 Python
python基于urllib实现按照百度音乐分类下载mp3的方法
2015/05/25 Python
Django中使用group_by的方法
2015/05/26 Python
bpython 功能强大的Python shell
2016/02/16 Python
如何运行.ipynb文件的图文讲解
2019/06/27 Python
python射线法判断一个点在图形区域内外
2019/06/28 Python
wxpython多线程防假死与线程间传递消息实例详解
2019/12/13 Python
python构造IP报文实例
2020/05/05 Python
New Balance美国官网:运动鞋和健身服装
2017/04/11 全球购物
美国知名生活购物网站:Goop
2017/11/03 全球购物
实习自我鉴定模板
2013/09/28 职场文书
车间主管岗位职责
2013/11/14 职场文书
计算机专业毕业生自我鉴定
2014/01/16 职场文书
致全体运动员广播稿
2014/02/01 职场文书
阳光体育运动标语口号
2015/12/26 职场文书