pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
介绍Python中几个常用的类方法
Apr 08 Python
Django 添加静态文件的两种实现方法(必看篇)
Jul 14 Python
简单学习Python多进程Multiprocessing
Aug 29 Python
python OpenCV学习笔记实现二维直方图
Feb 08 Python
python 对象和json互相转换方法
Mar 22 Python
django 开发忘记密码通过邮箱找回功能示例
Apr 17 Python
python 通过类中一个方法获取另一个方法变量的实例
Jan 22 Python
python使用递归的方式建立二叉树
Jul 03 Python
python中matplotlib条件背景颜色的实现
Sep 02 Python
Python django框架输入汉字,数字,字符生成二维码实现详解
Sep 24 Python
用基于python的appium爬取b站直播消费记录
Apr 17 Python
Python中itertools库的四个函数介绍
Apr 06 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP产生随机字符串函数
2006/12/06 PHP
php strtotime 函数UNIX时间戳
2009/01/14 PHP
PHP使用in_array函数检查数组中是否存在某个值
2015/03/25 PHP
Prototype 1.5.0_rc1 及 Prototype 1.5.0 Pre0小抄本
2006/09/22 Javascript
BOOTSTRAP时间控件显示在模态框下面的bug修复
2015/02/05 Javascript
js实现仿百度瀑布流的方法
2015/02/05 Javascript
javascript计时器详解
2015/02/28 Javascript
windows下安装nodejs及框架express
2015/08/07 NodeJs
javascript倒计时效果实现
2015/11/12 Javascript
Bootstrap选项卡与Masonry插件的完美结合
2016/07/06 Javascript
jQuery和CSS仿京东仿淘宝列表导航菜单
2017/01/04 Javascript
Vue的elementUI实现自定义主题方法
2018/02/23 Javascript
JS实现的JSON数组去重算法示例
2018/04/11 Javascript
JS封装的模仿qq右下角消息弹窗功能示例
2018/08/22 Javascript
JavaScript自动生成 年月范围 选择功能完整示例【基于jQuery插件】
2019/09/03 jQuery
JS实现音乐导航特效
2020/01/06 Javascript
Vue父子传递实例讲解
2020/02/14 Javascript
浅谈JavaScript中你可能不知道URL构造函数的属性
2020/07/13 Javascript
Javascript call及apply应用场景及实例
2020/08/26 Javascript
Python2.5/2.6实用教程 入门基础篇
2009/11/29 Python
Tornado Web服务器多进程启动的2个方法
2014/08/04 Python
Python 内置函数memoryview(obj)的具体用法
2017/11/23 Python
python3中函数参数的四种简单用法
2018/07/09 Python
详解python中__name__的意义以及作用
2019/08/07 Python
python word转pdf代码实例
2019/08/16 Python
详解HTML5中的标签
2015/06/19 HTML / CSS
详解canvas在圆弧周围绘制文本的两种写法
2018/05/22 HTML / CSS
Engel & Bengel官网:婴儿推车、儿童房家具和婴儿设备
2019/12/28 全球购物
SQL Server提供的3种恢复模型都是什么? 有什么区别?
2012/05/13 面试题
红旗团支部事迹材料
2014/01/27 职场文书
优秀医生事迹材料
2014/02/12 职场文书
校庆标语集锦
2014/06/25 职场文书
党在我心中演讲稿
2014/09/02 职场文书
个人年终总结开头
2015/03/06 职场文书
运动会运动员赞词
2015/07/22 职场文书
apache基于端口创建虚拟主机的示例
2021/04/24 Servers