pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python批量修改文件后缀的方法
Jan 26 Python
Python中字典的基础知识归纳小结
Aug 19 Python
Python冒泡排序注意要点实例详解
Sep 09 Python
Python搭建FTP服务器的方法示例
Jan 19 Python
详解pandas库pd.read_excel操作读取excel文件参数整理与实例
Feb 17 Python
利用anaconda作为python的依赖库管理方法
Aug 13 Python
python openvc 裁剪、剪切图片 提取图片的行和列
Sep 19 Python
python 3.74 运行import numpy as np 报错lib\site-packages\numpy\__init__.py
Oct 06 Python
Python如何发送与接收大型数组
Aug 07 Python
python 带时区的日期格式化操作
Oct 23 Python
完美解决torch.cuda.is_available()一直返回False的玄学方法
Feb 06 Python
python的netCDF4批量处理NC格式文件的操作方法
Mar 21 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP的一个基础知识 表单提交
2011/07/04 PHP
PHP版QQ互联OAuth示例代码分享
2015/07/05 PHP
PHP人民币金额转大写实例代码
2015/10/02 PHP
PHP 使用二进制保存用户状态的实例
2018/01/29 PHP
推荐:极酷右键菜单
2006/11/29 Javascript
setTimeout 不断吐食CPU的问题分析
2009/04/01 Javascript
原生JS实现表单checkbook获取已选择的值
2013/07/21 Javascript
jquery事件preventDefault()方法用法实例
2015/01/16 Javascript
纯JS代码实现一键分享功能
2016/04/20 Javascript
jquery表格datatables实例解析 直接加载和延迟加载
2016/08/12 Javascript
AngularJS使用自定义指令替代ng-repeat的方法
2016/09/17 Javascript
浅谈js键盘事件全面控制
2016/12/01 Javascript
使用JavaScript触发过渡效果的方法
2017/01/19 Javascript
thinkphp标签实现bootsrtap轮播carousel实例代码
2017/02/19 Javascript
vue2.0在没有dev-server.js下的本地数据配置方法
2018/02/23 Javascript
详解.vue文件解析的实现
2018/06/11 Javascript
微信小程序MUI侧滑导航菜单示例(Popup弹出式,左侧不动,右侧滑动)
2019/01/23 Javascript
JavaScript解析JSON数据示例
2019/07/16 Javascript
浅谈Vue项目骨架屏注入实践
2019/08/05 Javascript
vue实现设置载入动画和初始化页面动画效果
2019/10/28 Javascript
vue实现防抖的实例代码
2021/01/11 Vue.js
教你用Type Hint提高Python程序开发效率
2016/08/08 Python
[原创]教女朋友学Python(一)运行环境搭建
2017/11/29 Python
Python函数参数操作详解
2018/08/03 Python
Python发送邮件测试报告操作实例详解
2018/12/08 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
2019/06/13 Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
2019/08/17 Python
Python模拟登录和登录跳转的参考示例
2020/10/30 Python
如何在 Matplotlib 中更改绘图背景的实现
2020/11/26 Python
移动HTML5前端框架—MUI的使用
2017/12/18 HTML / CSS
jurlique茱莉蔻英国官网:澳洲天然护肤品
2018/08/03 全球购物
美国演唱会和体育门票购买网站:Ticketnetwork
2018/10/19 全球购物
Linux开机引导的步骤是什么
2014/02/26 面试题
后勤采购员岗位职责
2013/12/19 职场文书
中秋联欢会主持词
2015/07/04 职场文书
酒桌上的祝酒词
2015/08/12 职场文书