pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python 的枚举 Enum
Jun 12 Python
Python利用递归和walk()遍历目录文件的方法示例
Jul 14 Python
基于python3 类的属性、方法、封装、继承实例讲解
Sep 19 Python
解决Python字典写入文件出行首行有空格的问题
Sep 27 Python
NumPy 如何生成多维数组的方法
Feb 05 Python
利用python将pdf输出为txt的实例讲解
Apr 23 Python
python pandas修改列属性的方法详解
Jun 09 Python
详解分布式任务队列Celery使用说明
Nov 29 Python
matplotlib实现区域颜色填充
Mar 18 Python
Django单元测试中Fixtures的使用方法
Feb 26 Python
django 连接数据库出现1045错误的解决方式
May 14 Python
PyCharm+Miniconda3安装配置教程详解
Feb 16 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
Discuz 6.0+ 批量注册用户名
2009/09/13 PHP
PHP 数据结构队列(SplQueue)和优先队列(SplPriorityQueue)简单使用实例
2015/05/12 PHP
利用“多说”制作留言板、评论系统
2015/07/14 PHP
php版微信支付api.mch.weixin.qq.com域名解析慢原因与解决方法
2016/10/12 PHP
浅谈PHP的exec()函数无返回值排查方法(必看)
2017/03/31 PHP
详细解读php的命名空间(二)
2018/02/21 PHP
浅析PHP中json_encode与json_decode的区别
2020/07/15 PHP
javascript 单选框,多选框美化代码
2008/08/01 Javascript
javascript this用法小结
2008/12/19 Javascript
JQueryiframe页面操作父页面中的元素与方法(实例讲解)
2013/11/19 Javascript
javascript定义变量时带var与不带var的区别分析
2015/01/12 Javascript
AngularJS ng-mousedown 指令
2016/08/02 Javascript
JavaScript中三个等号和两个等号的区别(== 和 ===)浅析
2016/09/22 Javascript
手动初始化Angular的模块与控制器
2016/12/26 Javascript
Angular之toDoList的实现代码示例
2017/12/02 Javascript
解决axios发送post请求返回400状态码的问题
2018/08/11 Javascript
angularjs 动态从后台获取下拉框的值方法
2018/08/13 Javascript
详解使用WebPack搭建React开发环境
2019/08/06 Javascript
JavaScript异步操作的几种常见处理方法实例总结
2020/05/11 Javascript
深入分析jQuery.one() 函数
2020/06/03 jQuery
Vue结合路由配置递归实现菜单栏功能
2020/06/16 Javascript
vant自定义二级菜单操作
2020/11/02 Javascript
JavaScript 防盗链的原理以及破解方法
2020/12/29 Javascript
Python用户推荐系统曼哈顿算法实现完整代码
2017/12/01 Python
python实现自动化上线脚本的示例
2019/07/01 Python
使用Python爬虫库BeautifulSoup遍历文档树并对标签进行操作详解
2020/01/25 Python
Python实现链表反转的方法分析【迭代法与递归法】
2020/02/22 Python
Python 如何实现访问者模式
2020/07/28 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
2020/10/15 Python
详解python的super()的作用和原理
2020/10/29 Python
python time.strptime格式化实例详解
2021/02/03 Python
HTML5新控件之日期和时间选择输入的实现代码
2018/09/13 HTML / CSS
中专自荐信
2013/10/13 职场文书
大学生写自荐信的技巧
2014/01/08 职场文书
学习全国两会精神心得体会范文
2014/03/17 职场文书
基本公共卫生服务健康教育工作方案
2014/05/22 职场文书