pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python文档生成工具pydoc使用介绍
Jun 02 Python
python脚本设置系统时间的两种方法
Feb 21 Python
Python解析json文件相关知识学习
Mar 01 Python
WINDOWS 同时安装 python2 python3 后 pip 错误的解决方法
Mar 16 Python
代码分析Python地图坐标转换
Feb 08 Python
详解pandas中MultiIndex和对象实际索引不一致问题
Jul 23 Python
python scrapy爬虫代码及填坑
Aug 12 Python
python django 原生sql 获取数据的例子
Aug 14 Python
docker django无法访问redis容器的解决方法
Aug 21 Python
python获取依赖包和安装依赖包教程
Feb 13 Python
django序列化时使用外键的真实值操作
Jul 15 Python
使用Python解析Chrome浏览器书签的示例
Nov 13 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP中for循环语句的几种变型
2006/11/26 PHP
php生成缩略图填充白边(等比缩略图方案)
2013/12/25 PHP
php绘制一条弧线的方法
2015/01/24 PHP
Laravel 5框架学习之Blade 简介
2015/04/08 PHP
微信公众号判断用户是否已关注php代码解析
2016/06/24 PHP
PHP编译configure时常见错误的总结
2017/08/17 PHP
Gambit vs CL BO3 第三场 2.13
2021/03/10 DOTA
JavaScript 在线压缩和格式化收藏
2009/01/16 Javascript
根据鼠标的位置动态的控制层的位置
2009/11/24 Javascript
jquery 获取标签名(tagName)示例代码
2013/07/11 Javascript
jQuery选择器全面总结
2014/01/06 Javascript
sliderToggle在写jquery的计时器setTimeouter中不生效
2014/05/26 Javascript
javascript实现的简单计时器
2015/07/19 Javascript
jquery实现二级导航下拉菜单效果
2015/12/18 Javascript
javascript实现倒计时跳转页面
2016/01/17 Javascript
Bootstrap基本插件学习笔记之标签切换(17)
2016/12/08 Javascript
详解Vue 非父子组件通信方法(非Vuex)
2017/05/24 Javascript
JS按条件 serialize() 对应标签的使用方法
2017/07/24 Javascript
用js实现每隔一秒刷新时间的实例(含年月日时分秒)
2017/10/25 Javascript
Vue源码学习之关于对Array的数据侦听实现
2019/04/23 Javascript
JS实现处理时间,年月日,星期的公共方法示例
2019/05/31 Javascript
[51:27]LGD vs Liquid 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/19 DOTA
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
python中安装Scrapy模块依赖包汇总
2017/07/02 Python
Python WXPY实现微信监控报警功能的代码
2017/10/20 Python
python查看列的唯一值方法
2018/07/17 Python
Python 写了个新型冠状病毒疫情传播模拟程序
2020/02/14 Python
Python内建序列通用操作6种实现方法
2020/03/26 Python
python爬虫 requests-html的使用
2020/11/30 Python
Django中使用Celery的方法步骤
2020/12/07 Python
html5的画布canvas——画出弧线、旋转的图形实例代码+效果图
2013/06/09 HTML / CSS
工程招投标邀请书
2014/01/26 职场文书
2016年圣诞节活动总结范文
2016/04/01 职场文书
读《皮囊》有感:理解是对他人的最大的善举
2019/11/14 职场文书
使用php的mail()函数实现发送邮件功能
2021/06/03 PHP
pandas时间序列之pd.to_datetime()的实现
2022/06/16 Python