pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python查找相似单词的方法
Mar 05 Python
Python编程实现删除VC临时文件及Debug目录的方法
Mar 22 Python
对Python中数组的几种使用方法总结
Jun 28 Python
分析python请求数据
Aug 19 Python
python使用pymongo操作mongo的完整步骤
Apr 13 Python
Python+OpenCV+pyQt5录制双目摄像头视频的实例
Jun 28 Python
使用pyhon绘图比较两个手机屏幕大小(实例代码)
Jan 03 Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 Python
如何利用python之wxpy模块玩转微信
Aug 17 Python
Pycharm 设置默认解释器路径和编码格式的操作
Feb 05 Python
浅谈Python列表嵌套字典转化的问题
Apr 07 Python
Python预测分词的实现
Jun 18 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
php 网页播放器用来播放在线视频的代码(自动判断并选择视频文件类型)
2010/06/03 PHP
Codeigniter实现处理用户登录验证后的URL跳转
2014/06/12 PHP
网站防止被刷票的一些思路与方法
2015/01/08 PHP
Yii2.0使用阿里云OSS的SDK上传图片、下载、删除图片示例
2017/09/20 PHP
php面试中关于面向对象的相关问题
2019/02/13 PHP
javascript 动态设置已知select的option的value值的代码
2009/12/16 Javascript
js实现单一html页面两套css切换代码
2013/04/11 Javascript
跨域资源共享 CORS 详解
2016/04/26 Javascript
JS判断form内所有表单是否为空的简单实例
2016/09/09 Javascript
微信小程序使用第三方库Underscore.js步骤详解
2016/09/27 Javascript
利用js查找数组中指定元素并返回该元素的所有索引示例
2017/03/29 Javascript
jQuery Form插件使用详解_动力节点Java学院整理
2017/07/17 jQuery
浅谈mint-ui 填坑之路
2017/11/06 Javascript
vue自定义组件实现双向绑定
2021/01/13 Vue.js
[04:41]2014DOTA2国际邀请赛 Liquid顺利突围晋级正赛
2014/07/09 DOTA
Python实现的Kmeans++算法实例
2014/04/26 Python
在Python中操作时间之tzset()方法的使用教程
2015/05/22 Python
python中根据字符串调用函数的实现方法
2016/06/12 Python
Django ORM框架的定时任务如何使用详解
2017/10/19 Python
python中实现精确的浮点数运算详解
2017/11/02 Python
使用python实现ANN
2017/12/20 Python
Python3实现的旋转矩阵图像算法示例
2019/04/03 Python
Django框架实现分页显示内容的方法详解
2019/05/10 Python
详解Python 4.0 预计推出的新功能
2019/07/26 Python
python 计算积分图和haar特征的实例代码
2019/11/20 Python
Python之Class&Object用法详解
2019/12/25 Python
Ann Taylor官方网站:美国最大的女性产品制造商之一
2016/09/14 全球购物
GIVENCHY纪梵希官方旗舰店:高定彩妆与贵族护肤品
2018/04/16 全球购物
审核会计岗位职责
2013/11/08 职场文书
优秀生推荐信范文
2013/11/28 职场文书
公司年会主持词
2014/03/22 职场文书
演讲稿的写法
2014/05/19 职场文书
经济贸易系毕业生求职信
2014/05/31 职场文书
大学学雷锋活动总结
2014/06/26 职场文书
城市规划应届毕业生自荐信
2014/07/04 职场文书
用python画城市轮播地图
2021/05/28 Python