pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python datetime时间格式化去掉前导0
Jul 31 Python
Python实现简单状态框架的方法
Mar 19 Python
python使用三角迭代计算圆周率PI的方法
Mar 20 Python
python正则表达式match和search用法实例
Mar 26 Python
Python urllib、urllib2、httplib抓取网页代码实例
May 09 Python
Python的Django框架可适配的各种数据库介绍
Jul 15 Python
对python:threading.Thread类的使用方法详解
Jan 31 Python
使用Python做垃圾分类的原理及实例代码附源码
Jul 02 Python
详解python中的数据类型和控制流
Aug 08 Python
Python+OpenCV实现旋转文本校正方式
Jan 09 Python
Python标准库itertools的使用方法
Jan 17 Python
django3.02模板中的超链接配置实例代码
Feb 04 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP 得到根目录的 __FILE__ 常量
2008/07/23 PHP
PHP通过header实现文本文件下载的代码
2010/08/08 PHP
laravel框架实现为 Blade 模板引擎添加新文件扩展名操作示例
2020/01/25 PHP
javascript 树控件 比较好用
2009/06/11 Javascript
JS简单的轮播的图片滚动实例
2013/06/17 Javascript
JavaScript的Module模式编程深入分析
2013/08/13 Javascript
跨域传值即主页面与iframe之间互相传值
2013/12/09 Javascript
js无刷新操作table的行和列
2014/03/27 Javascript
jquery操作HTML5 的data-*的用法实例分享
2014/08/17 Javascript
jQuery使用正则表达式限制文本框只能输入数字
2016/06/18 Javascript
Vue.js组件使用开发实例教程
2016/11/01 Javascript
Bootstrap基本组件学习笔记之进度条(15)
2016/12/08 Javascript
js cookie实现记住密码功能
2017/01/17 Javascript
浅谈VUE监听窗口变化事件的问题
2018/02/24 Javascript
vue中如何实现pdf文件预览的方法
2018/07/12 Javascript
Vue将页面导出为图片或者PDF
2020/08/17 Javascript
使用Typescript和ES模块发布Node模块的方法
2020/05/25 Javascript
[24:42]VP vs TNC Supermajor小组赛B组 BO3 第三场 6.2
2018/06/03 DOTA
python实现在sqlite动态创建表的方法
2015/05/08 Python
浅谈python中列表、字符串、字典的常用操作
2017/09/19 Python
Python实现的双色球生成功能示例
2017/12/18 Python
Python中elasticsearch插入和更新数据的实现方法
2018/04/01 Python
Python实现的网页截图功能【PyQt4与selenium组件】
2018/07/12 Python
python将一组数分成每3个一组的实例
2018/11/14 Python
在win10和linux上分别安装Python虚拟环境的方法步骤
2019/05/09 Python
解决webdriver.Chrome()报错:Message:'chromedriver' executable needs to be in Path
2019/06/12 Python
Python Numpy中数据的常用保存与读取方法
2020/04/01 Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
2020/06/24 Python
python实现取余操作的简单实例
2020/08/16 Python
详解CSS3中强大的filter(滤镜)属性
2017/06/29 HTML / CSS
90后毕业生的求职信范文
2013/09/21 职场文书
新店开张活动方案
2014/08/24 职场文书
2014年扶贫帮困工作总结
2014/12/09 职场文书
辅导员学期工作总结
2015/08/14 职场文书
如何使用Python提取Chrome浏览器保存的密码
2021/06/09 Python
python数字图像处理:图像简单滤波
2022/06/28 Python