变长双向rnn的正确使用姿势教学


Posted in Python onMay 31, 2021

如何使用双向RNN

在《深度学习之TensorFlow入门、原理与进阶实战》一书的9.4.2中的第4小节中,介绍过变长动态RNN的实现。

这里在来延伸的讲解一下双向动态rnn在处理变长序列时的应用。其实双向RNN的使用中,有一个隐含的注意事项,非常容易犯错。

本文就在介绍下双向RNN的常用函数、用法及注意事项。

动态双向rnn有两个函数:

stack_bidirectional_dynamic_rnn
bidirectional_dynamic_rnn

二者的实现上大同小异,放置的位置也不一样,前者放在contrib下面,而后者显得更加根红苗正,放在了tf的核心库下面。在使用时二者的返回值也有所区别。下面就来一一介绍。

示例代码

先以GRU的cell代码为例:

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 创建输入数据
X = np.random.randn(2, 4, 5)# 批次 、序列长度、样本维度
# 第二个样本长度为3
X[1,2:] = 0
seq_lengths = [4, 2]
Gstacked_rnn = []
Gstacked_bw_rnn = []
for i in range(3):
    Gstacked_rnn.append(tf.contrib.rnn.GRUCell(3))
    Gstacked_bw_rnn.append(tf.contrib.rnn.GRUCell(3))
#建立前向和后向的三层RNN
Gmcell = tf.contrib.rnn.MultiRNNCell(Gstacked_rnn)
Gmcell_bw = tf.contrib.rnn.MultiRNNCell(Gstacked_bw_rnn)
sGbioutputs, sGoutput_state_fw, sGoutput_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([Gmcell],[Gmcell_bw], X,sequence_length=seq_lengths,                                           dtype=tf.float64)
Gbioutputs, Goutput_state_fw = tf.nn.bidirectional_dynamic_rnn(Gmcell,Gmcell_bw, X,sequence_length=seq_lengths,dtype=tf.float64)

上面例子中是创建双向RNN的方法示例。可以看到带有stack的双向RNN会输出3个返回值,而不带有stack的双向RNN会输出2个返回值。

这里面还要注意的是,在没有未cell初始化时必须要将dtype参数赋值。不然会报错。

代码:BiRNN输出

下面添加代码,将输出的值打印出来,看一下,这两个函数到底是输出的是啥?

#建立一个会话
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sgbresult,sgstate_fw,sgstate_bw=sess.run([sGbioutputs,sGoutput_state_fw,sGoutput_state_bw])
print("全序列:\n", sgbresult[0])
print("短序列:\n", sgbresult[1])
print('Gru的状态:',len(sgstate_fw[0]),'\n',sgstate_fw[0][0],'\n',sgstate_fw[0][1],'\n',sgstate_fw[0][2])
print('Gru的状态:',len(sgstate_bw[0]),'\n',sgstate_bw[0][0],'\n',sgstate_bw[0][1],'\n',sgstate_bw[0][2])

先看一下带有stack的双向RNN输出的内容:

变长双向rnn的正确使用姿势教学

我们输入的数据的批次是2,第一个序列长度是4,第二个序列长度是2.

图中共有4部分输出,可以看到,第一部分(全序列)就是序列长度为4的结果,第二部分(短序列)就是序列长度为2的结果。由于没一层都是由3个RNN的GRU cell组成,所以每个序列的输出都为3.很显然,对于这样的结果输出,必须要将短序列后面的0去掉才可以用。

好在该函数还有第二个输出值,GRU的状态。可以直接使用状态里的值,而不需要对原始结果进行去0的变化。

由于单个GRU本来就是没有状态的。所以该函数将最后的输出作为状态返回。该函数有两个状态返回,分别代表前向和后向。每一个方向的状态都会返回3个元素。这是因为每个方向的网络都有3层GRU组成。在使用时,一般都会取最后一个状态。图中红色部分为前向中,两个样本对应的输出,这个很好理解。

重点要看蓝色的部分,即反向的状态值对应的是原始数据中最其实的序列输入。因为是反向RNN,在反向循环时,是会把序列中最后的放在最前面,所以反向网络的生成结果就会与最开始的序列相对应。

对于特征提取任务处理时,正向与反向的最后值都为该序列的特征,需要合并起来统一处理。但是对于下一个序列预测任务时,建议直接使用正向的RNN网络就可以了。

如果要获取双向RNN的结果,尤其是变长情况下,通过状态拿到值直接拼接起来才是正确的做法。即便不是变长。直接使用输出值来拼接,会损失掉反向的一部分特征结果。这是需要值得注意的地方。

代码:BiRNN输出

好了。在接着看下不带stack的函数输出是什么样子的

gbresult,state_fw=sess.run([Gbioutputs,Goutput_state_fw])
print("正向:\n", gbresult[0])
print("反向:\n", gbresult[1])
print('状态:',len(state_fw),'\n',state_fw[0],'\n',state_fw[1])  #state_fw[0]:【层,批次,cell个数】 重头到最后一个序列
print(state_fw[0][-1],state_fw[1][-1])
out  = np.concatenate((state_fw[0][-1],state_fw[1][-1]),axis = 1)
print("拼接",out)

这次,在输出基本内容基础上,直接将结果拼接起来。上面代码运行后会输出如下内容。

变长双向rnn的正确使用姿势教学

同样正向用红色,反向用蓝色。改函数返回的输出值,没有将正反向拼接。输出的状态虽然是一个值,但是里面有两个元素,一个代表正向状态,一个代表反向状态.

从输出中可以看到,最后一行实现了最终结果的真正拼接。在使用双向rnn时可以按照上面的例子代码将其状态拼接成一条完整输出,然后在进行处理。

代码:LSTM的双向RNN

类似的如果想使用LSTM cell。将前面的GRU部分替换即可,代码如下:

stacked_rnn = []
stacked_bw_rnn = []
for i in range(3):
    stacked_rnn.append(tf.contrib.rnn.LSTMCell(3))
    stacked_bw_rnn.append(tf.contrib.rnn.LSTMCell(3))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn)
mcell_bw = tf.contrib.rnn.MultiRNNCell(stacked_bw_rnn)    
bioutputs, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([mcell],[mcell_bw], X,sequence_length=seq_lengths,
                                              dtype=tf.float64)
bioutputs, output_state_fw = tf.nn.bidirectional_dynamic_rnn(mcell,mcell_bw, X,sequence_length=seq_lengths,
                                              dtype=tf.float64)

至于输出的内容是什么,可以按照前面GRU的输出部分显示出来自己观察。如何拼接,也可以参照GRU的例子来做。

通过将正反向的状态拼接起来才可以获得双向RNN的最终输出特征。千万不要直接拿着输出不加处理的来进行后续的运算,这会损失一大部分的运算特征。

该部分内容属于《深度学习之TensorFlow入门、原理与进阶实战》一书的内容补充。关于RNN的更多介绍可以参看书中第九章的详细内容。

我对双向RNN 的理解

1、双向RNN使用的场景:有些情况下,当前的输出不只依赖于之前的序列元素,还可能依赖之后的序列元素; 比如做完形填空,机器翻译等应用。

变长双向rnn的正确使用姿势教学

2、Tensorflow 中实现双向RNN 的API是:bidirectional_dynamic_rnn; 其本质主要是做了两次reverse:

第一次reverse:将输入序列进行reverse,然后送入dynamic_rnn做一次运算.

第二次reverse:将上面dynamic_rnn返回的outputs进行reverse,保证正向和反向输出的time是对上的.

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程实现蚁群算法详解
Nov 13 Python
Python enumerate索引迭代代码解析
Jan 19 Python
Python装饰器简单用法实例小结
Dec 03 Python
python从入门到精通 windows安装python图文教程
May 18 Python
详解将Python程序(.py)转换为Windows可执行文件(.exe)
Jul 19 Python
Django urls.py重构及参数传递详解
Jul 23 Python
使用OpCode绕过Python沙箱的方法详解
Sep 03 Python
详解python itertools功能
Feb 07 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 Python
python如何提升爬虫效率
Sep 27 Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 Python
Python OpenCV实现传统图片格式与base64转换
Jun 13 Python
如何在Python项目中引入日志
Tensorflow与RNN、双向LSTM等的踩坑记录及解决
Python数据类型最全知识总结
May 31 #Python
教你怎么用Python操作MySql数据库
Django集成富文本编辑器summernote的实现步骤
Python基础知识学习之类的继承
May 31 #Python
Django实现聊天机器人
You might like
超强分页类2.0发布,支持自定义风格,默认4种显示模式
2007/01/02 PHP
php生成N个不重复的随机数实例
2013/11/12 PHP
thinkphp浏览历史功能实现方法
2014/10/29 PHP
php使用Jpgraph创建折线图效果示例
2017/02/15 PHP
PHP后门隐藏的一些技巧总结
2020/11/04 PHP
PHP中的输出echo、print、printf、sprintf、print_r和var_dump的示例代码
2020/12/01 PHP
番茄的表单验证类代码修改版
2008/07/18 Javascript
javascript function、指针及内置对象
2009/02/19 Javascript
Javascript的时间戳和php的时间戳转换注意事项
2013/04/12 Javascript
JS实现两个大数(整数)相乘
2014/04/28 Javascript
JS访问SWF的函数用法实例
2015/07/01 Javascript
原生JS封装Ajax插件(同域、jsonp跨域)
2016/05/03 Javascript
浅谈jquery设置和获得checkbox选中的问题
2016/08/19 Javascript
Vue.js使用$.ajax和vue-resource实现OAuth的注册、登录、注销和API调用
2017/05/10 Javascript
react native与webview通信的示例代码
2017/09/25 Javascript
vue-router 手势滑动触发返回功能
2018/09/30 Javascript
Bootstrap table 服务器端分页功能实现方法示例
2020/06/01 Javascript
js实现飞机大战游戏
2020/08/26 Javascript
[00:32]2018DOTA2亚洲邀请赛Mineski出场
2018/04/04 DOTA
Python中的进程分支fork和exec详解
2015/04/11 Python
Django中传递参数到URLconf的视图函数中的方法
2015/07/18 Python
Python信息抽取之乱码解决办法
2017/06/29 Python
python虚拟环境virtualenv的安装与使用
2017/09/21 Python
python中numpy的矩阵、多维数组的用法
2018/02/05 Python
python中partial()基础用法说明
2018/12/30 Python
Python使用itchat 功能分析微信好友性别和位置
2019/08/05 Python
html5 css3 动态气泡按钮实例演示
2012/12/02 HTML / CSS
基于CSS3的animation属性实现微信拍一拍动画效果
2020/06/22 HTML / CSS
HTML5手指下滑弹出负一屏阻止移动端浏览器内置下拉刷新功能的实现代码
2020/04/10 HTML / CSS
美国流行背包品牌:JanSport(杰斯伯)
2018/03/02 全球购物
您附近的水疗和健康场所:Spafinder(美国)
2019/07/05 全球购物
自荐书范文
2013/12/08 职场文书
麦当劳辞职信范文
2014/01/18 职场文书
银行开户授权委托书格式
2014/10/10 职场文书
python 中yaml文件用法大全
2021/07/04 Python
MySQL深分页问题解决思路
2022/12/24 MySQL