变长双向rnn的正确使用姿势教学


Posted in Python onMay 31, 2021

如何使用双向RNN

在《深度学习之TensorFlow入门、原理与进阶实战》一书的9.4.2中的第4小节中,介绍过变长动态RNN的实现。

这里在来延伸的讲解一下双向动态rnn在处理变长序列时的应用。其实双向RNN的使用中,有一个隐含的注意事项,非常容易犯错。

本文就在介绍下双向RNN的常用函数、用法及注意事项。

动态双向rnn有两个函数:

stack_bidirectional_dynamic_rnn
bidirectional_dynamic_rnn

二者的实现上大同小异,放置的位置也不一样,前者放在contrib下面,而后者显得更加根红苗正,放在了tf的核心库下面。在使用时二者的返回值也有所区别。下面就来一一介绍。

示例代码

先以GRU的cell代码为例:

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 创建输入数据
X = np.random.randn(2, 4, 5)# 批次 、序列长度、样本维度
# 第二个样本长度为3
X[1,2:] = 0
seq_lengths = [4, 2]
Gstacked_rnn = []
Gstacked_bw_rnn = []
for i in range(3):
    Gstacked_rnn.append(tf.contrib.rnn.GRUCell(3))
    Gstacked_bw_rnn.append(tf.contrib.rnn.GRUCell(3))
#建立前向和后向的三层RNN
Gmcell = tf.contrib.rnn.MultiRNNCell(Gstacked_rnn)
Gmcell_bw = tf.contrib.rnn.MultiRNNCell(Gstacked_bw_rnn)
sGbioutputs, sGoutput_state_fw, sGoutput_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([Gmcell],[Gmcell_bw], X,sequence_length=seq_lengths,                                           dtype=tf.float64)
Gbioutputs, Goutput_state_fw = tf.nn.bidirectional_dynamic_rnn(Gmcell,Gmcell_bw, X,sequence_length=seq_lengths,dtype=tf.float64)

上面例子中是创建双向RNN的方法示例。可以看到带有stack的双向RNN会输出3个返回值,而不带有stack的双向RNN会输出2个返回值。

这里面还要注意的是,在没有未cell初始化时必须要将dtype参数赋值。不然会报错。

代码:BiRNN输出

下面添加代码,将输出的值打印出来,看一下,这两个函数到底是输出的是啥?

#建立一个会话
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sgbresult,sgstate_fw,sgstate_bw=sess.run([sGbioutputs,sGoutput_state_fw,sGoutput_state_bw])
print("全序列:\n", sgbresult[0])
print("短序列:\n", sgbresult[1])
print('Gru的状态:',len(sgstate_fw[0]),'\n',sgstate_fw[0][0],'\n',sgstate_fw[0][1],'\n',sgstate_fw[0][2])
print('Gru的状态:',len(sgstate_bw[0]),'\n',sgstate_bw[0][0],'\n',sgstate_bw[0][1],'\n',sgstate_bw[0][2])

先看一下带有stack的双向RNN输出的内容:

变长双向rnn的正确使用姿势教学

我们输入的数据的批次是2,第一个序列长度是4,第二个序列长度是2.

图中共有4部分输出,可以看到,第一部分(全序列)就是序列长度为4的结果,第二部分(短序列)就是序列长度为2的结果。由于没一层都是由3个RNN的GRU cell组成,所以每个序列的输出都为3.很显然,对于这样的结果输出,必须要将短序列后面的0去掉才可以用。

好在该函数还有第二个输出值,GRU的状态。可以直接使用状态里的值,而不需要对原始结果进行去0的变化。

由于单个GRU本来就是没有状态的。所以该函数将最后的输出作为状态返回。该函数有两个状态返回,分别代表前向和后向。每一个方向的状态都会返回3个元素。这是因为每个方向的网络都有3层GRU组成。在使用时,一般都会取最后一个状态。图中红色部分为前向中,两个样本对应的输出,这个很好理解。

重点要看蓝色的部分,即反向的状态值对应的是原始数据中最其实的序列输入。因为是反向RNN,在反向循环时,是会把序列中最后的放在最前面,所以反向网络的生成结果就会与最开始的序列相对应。

对于特征提取任务处理时,正向与反向的最后值都为该序列的特征,需要合并起来统一处理。但是对于下一个序列预测任务时,建议直接使用正向的RNN网络就可以了。

如果要获取双向RNN的结果,尤其是变长情况下,通过状态拿到值直接拼接起来才是正确的做法。即便不是变长。直接使用输出值来拼接,会损失掉反向的一部分特征结果。这是需要值得注意的地方。

代码:BiRNN输出

好了。在接着看下不带stack的函数输出是什么样子的

gbresult,state_fw=sess.run([Gbioutputs,Goutput_state_fw])
print("正向:\n", gbresult[0])
print("反向:\n", gbresult[1])
print('状态:',len(state_fw),'\n',state_fw[0],'\n',state_fw[1])  #state_fw[0]:【层,批次,cell个数】 重头到最后一个序列
print(state_fw[0][-1],state_fw[1][-1])
out  = np.concatenate((state_fw[0][-1],state_fw[1][-1]),axis = 1)
print("拼接",out)

这次,在输出基本内容基础上,直接将结果拼接起来。上面代码运行后会输出如下内容。

变长双向rnn的正确使用姿势教学

同样正向用红色,反向用蓝色。改函数返回的输出值,没有将正反向拼接。输出的状态虽然是一个值,但是里面有两个元素,一个代表正向状态,一个代表反向状态.

从输出中可以看到,最后一行实现了最终结果的真正拼接。在使用双向rnn时可以按照上面的例子代码将其状态拼接成一条完整输出,然后在进行处理。

代码:LSTM的双向RNN

类似的如果想使用LSTM cell。将前面的GRU部分替换即可,代码如下:

stacked_rnn = []
stacked_bw_rnn = []
for i in range(3):
    stacked_rnn.append(tf.contrib.rnn.LSTMCell(3))
    stacked_bw_rnn.append(tf.contrib.rnn.LSTMCell(3))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn)
mcell_bw = tf.contrib.rnn.MultiRNNCell(stacked_bw_rnn)    
bioutputs, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([mcell],[mcell_bw], X,sequence_length=seq_lengths,
                                              dtype=tf.float64)
bioutputs, output_state_fw = tf.nn.bidirectional_dynamic_rnn(mcell,mcell_bw, X,sequence_length=seq_lengths,
                                              dtype=tf.float64)

至于输出的内容是什么,可以按照前面GRU的输出部分显示出来自己观察。如何拼接,也可以参照GRU的例子来做。

通过将正反向的状态拼接起来才可以获得双向RNN的最终输出特征。千万不要直接拿着输出不加处理的来进行后续的运算,这会损失一大部分的运算特征。

该部分内容属于《深度学习之TensorFlow入门、原理与进阶实战》一书的内容补充。关于RNN的更多介绍可以参看书中第九章的详细内容。

我对双向RNN 的理解

1、双向RNN使用的场景:有些情况下,当前的输出不只依赖于之前的序列元素,还可能依赖之后的序列元素; 比如做完形填空,机器翻译等应用。

变长双向rnn的正确使用姿势教学

2、Tensorflow 中实现双向RNN 的API是:bidirectional_dynamic_rnn; 其本质主要是做了两次reverse:

第一次reverse:将输入序列进行reverse,然后送入dynamic_rnn做一次运算.

第二次reverse:将上面dynamic_rnn返回的outputs进行reverse,保证正向和反向输出的time是对上的.

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 分析Nginx访问日志并保存到MySQL数据库实例
Mar 13 Python
Python计算一个文件里字数的方法
Jun 15 Python
Windows中安装使用Virtualenv来创建独立Python环境
May 31 Python
python使用fcntl模块实现程序加锁功能示例
Jun 23 Python
Python 3中print函数的使用方法总结
Aug 08 Python
遗传算法python版
Mar 19 Python
Django中使用session保持用户登陆连接的例子
Aug 06 Python
python实现输入任意一个大写字母生成金字塔的示例
Oct 27 Python
Pycharm激活方法及详细教程(详细且实用)
May 12 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
python 判断一组数据是否符合正态分布
Sep 23 Python
深入浅析python3 依赖倒置原则(示例代码)
Jul 09 Python
如何在Python项目中引入日志
Tensorflow与RNN、双向LSTM等的踩坑记录及解决
Python数据类型最全知识总结
May 31 #Python
教你怎么用Python操作MySql数据库
Django集成富文本编辑器summernote的实现步骤
Python基础知识学习之类的继承
May 31 #Python
Django实现聊天机器人
You might like
PHP环境搭建最新方法
2006/09/05 PHP
Http 1.1 Etag 与 Last-Modified提高php效率
2008/01/10 PHP
PHP 实现base64编码文件上传出现问题详解
2020/09/01 PHP
jquery禁止输入数字以外的字符的示例(纯数字验证码)
2014/04/10 Javascript
jquery之别踩白块游戏的简单实现
2016/07/25 Javascript
Vuejs第九篇之组件作用域及props数据传递实例详解
2016/09/05 Javascript
在localStorage中存储对象数组并读取的方法
2016/09/24 Javascript
详解vue-router 2.0 常用基础知识点之导航钩子
2017/05/10 Javascript
全面介绍vue 全家桶和项目实例
2017/12/27 Javascript
Vue.js组件间的循环引用方法示例
2017/12/27 Javascript
详解解决小程序中webview页面多层history返回问题
2019/08/20 Javascript
java实现单链表增删改查的实例代码详解
2019/08/30 Javascript
[03:56]显微镜下的DOTA2第十一期——鬼畜的死亡先知播音员
2014/06/23 DOTA
[07:59]2014DOTA2叨叨刀塔 林熊猫称被邀请赛现场盛况震撼
2014/07/21 DOTA
Python下singleton模式的实现方法
2014/07/16 Python
解决Python中由于logging模块误用导致的内存泄露
2015/04/23 Python
Python全局变量用法实例分析
2016/07/19 Python
修复CentOS7升级Python到3.6版本后yum不能正确使用的解决方法
2018/01/26 Python
解决python3中解压zip文件是文件名乱码的问题
2018/03/22 Python
Python函数any()和all()的用法及区别介绍
2018/09/14 Python
Django 通过JS实现ajax过程详解
2019/07/30 Python
基于Python安装pyecharts所遇的问题及解决方法
2019/08/12 Python
在django中实现页面倒数几秒后自动跳转的例子
2019/08/16 Python
在python中使用pymysql往mysql数据库中插入(insert)数据实例
2020/03/02 Python
Python实现验证码识别
2020/06/15 Python
推荐技术人员一款Python开源库(造数据神器)
2020/07/08 Python
几款好用的python工具库(小结)
2020/10/20 Python
Django缓存Cache使用详解
2020/11/30 Python
南京软件公司的.net程序员笔试题
2014/08/31 面试题
篮球比赛口号
2014/06/10 职场文书
校本教研活动总结
2014/07/01 职场文书
小学六一儿童节活动总结
2015/05/05 职场文书
使用numpy nonzero 找出非0元素
2021/05/14 Python
Redis可视化客户端小结
2021/06/10 Redis
Python实现滑雪小游戏
2021/09/25 Python
Redis官方可视化工具RedisInsight安装使用教程
2022/04/19 Redis