双向RNN:bidirectional_dynamic_rnn()函数的使用详解


Posted in Python onJanuary 20, 2020

双向RNN:bidirectional_dynamic_rnn()函数的使用详解

先说下为什么要使用到双向RNN,在读一篇文章的时候,上文提到的信息十分的重要,但这些信息是不足以捕捉文章信息的,下文隐含的信息同样会对该时刻的语义产生影响。

举一个不太恰当的例子,某次工作会议上,领导进行“简洁地”总结,他会在第一句告诉你:“下面,为了节约时间,我简单地说两点…”,(…此处略去五百字…),“首先,….”,(…此处略去一万字…),“碍于时间的关系,我要加快速度了,下面我简要说下第二点…”(…此处再次略去五千字…)“好的,我想说的大概就是这些”(…此处又略去了二百字…),“谢谢大家!”如果将这篇发言交给一个单层的RNN网络去学习,因为“首先”和“第二点”中间隔得实在太久,等到开始学习“第二点”时,网络已经忘记了“简单地说两点”这个重要的信息,最终的结果就只剩下在风中凌乱了。。。于是我们决定加一个反向的网络,从后开始往前听,对于这层网络,他首先听到的就是“第二点”,然后是“首先”,最后,他对比了一下果然仅仅是“简要地两点”,在于前向的网络进行结合,就深入学习了领导的指导精神。

双向RNN:bidirectional_dynamic_rnn()函数的使用详解

上图是一个双向LSTM的结构图,对于最后输出的每个隐藏状态双向RNN:bidirectional_dynamic_rnn()函数的使用详解 都是前向网络和后向网络的元组,即双向RNN:bidirectional_dynamic_rnn()函数的使用详解 其中每一个双向RNN:bidirectional_dynamic_rnn()函数的使用详解 或者双向RNN:bidirectional_dynamic_rnn()函数的使用详解 又是一个由隐藏状态和细胞状态组成的元组(或者是concat)。同样最终的output也是需要将前向和后向的输出concat起来的,这样就保证了在最终时刻,无论是输出还是隐藏状态都是有考虑了上文和下文信息的。

下面就来看下tensorflow中已经集成的 tf.nn.bidirectional_dynamic_rnn() 函数。似乎双向的暂时只有这一个动态的RNN方法,不过想想也能理解,这种结构暂时也只会在encoder端出现,无论你的输入是pad到了定长或者是不定长的,动态RNN都是可以处理的。

具体的定义如下:

tf.nn.bidirectional_dynamic_rnn(
 cell_fw,
 cell_bw,
 inputs,
 sequence_length=None,
 initial_state_fw=None,
 initial_state_bw=None,
 dtype=None,
 parallel_iterations=None,
 swap_memory=False,
 time_major=False,
 scope=None
)

仔细看这个方法似乎和dynamic_rnn()没有太大区别,无非是多加了一个bw的部分,事实上也的确如此。先看下前向传播的部分:

with vs.variable_scope(scope or "bidirectional_rnn"):
 # Forward direction
 with vs.variable_scope("fw") as fw_scope:
  output_fw, output_state_fw = dynamic_rnn(
    cell=cell_fw, inputs=inputs, 
    sequence_length=sequence_length,
    initial_state=initial_state_fw, 
    dtype=dtype,
    parallel_iterations=parallel_iterations, 
    swap_memory=swap_memory,
    scope=fw_scope)

完全就是一个dynamic_rnn(),至于你选择LSTM或者GRU,只是cell的定义不同罢了。而双向RNN的核心就在于反向的bw部分。刚才说过,反向部分就是从后往前读,而这个翻转的部分,就要用到一个reverse_sequence()的方法,来看一下这一部分:

with vs.variable_scope("bw") as bw_scope:
 # ———————————— 此处是重点 ———————————— 
 inputs_reverse = _reverse(
   inputs, seq_lengths=sequence_length,
   seq_dim=time_dim, batch_dim=batch_dim)
 # ————————————————————————————————————
 tmp, output_state_bw = dynamic_rnn(
   cell=cell_bw, 
   inputs=inputs_reverse, 
   sequence_length=sequence_length,
   initial_state=initial_state_bw, 
   dtype=dtype,
   parallel_iterations=parallel_iterations,
   swap_memory=swap_memory,
   time_major=time_major, 
   scope=bw_scope)

我们可以看到,这里的输入不再是inputs,而是一个inputs_reverse,根据time_major的取值,time_dim和batch_dim组合的 {0,1} 取值正好相反,也就对应了时间维和批量维的词序关系。

而最终的输出:

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

这里还有最后的一个小问题,output_states是一个元组的元组,我个人的处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态。

以上这篇双向RNN:bidirectional_dynamic_rnn()函数的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫抓取手机APP的传输数据
Jan 22 Python
python生成器表达式和列表解析
Mar 10 Python
Python实现简单的语音识别系统
Dec 13 Python
Python3实现将本地JSON大数据文件写入MySQL数据库的方法
Jun 13 Python
解决python ogr shp字段写入中文乱码的问题
Dec 31 Python
pycharm中显示CSS提示的知识点总结
Jul 29 Python
使用OpenCV-python3实现滑动条更新图像的Canny边缘检测功能
Dec 12 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 Python
python中导入 train_test_split提示错误的解决
Jun 19 Python
Python爬虫之App爬虫视频下载的实现
Dec 08 Python
基于Python编写简易版的天天跑酷游戏的示例代码
Mar 23 Python
Python OpenGL基本配置方式
May 20 Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
You might like
Cappuccino 卡布其诺咖啡之制作
2021/03/03 冲泡冲煮
php empty()与isset()区别的详细介绍
2013/06/17 PHP
解析PHP跨站刷票的实现代码
2013/06/18 PHP
PHP使用ActiveMQ实现消息队列的方法详解
2019/05/31 PHP
js获取图片大小的函数代码
2011/09/20 Javascript
js加强的经典分页实例
2013/03/15 Javascript
jquery解析XML字符串和XML文件的方法说明
2014/02/21 Javascript
js 单引号替换成双引号,双引号替换成单引号的实现方法
2017/02/16 Javascript
Vue 2.x教程之基础API
2017/03/06 Javascript
js实现不提示直接关闭网页窗口
2017/03/30 Javascript
Bootstrap Table使用整理(二)
2017/06/09 Javascript
JavaScript变量作用域_动力节点Java学院整理
2017/06/27 Javascript
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
vue + webpack如何绕过QQ音乐接口对host的验证详解
2018/07/01 Javascript
JavaScript事件对象深入详解
2018/12/30 Javascript
js 下拉菜单点击旁边收起实现(踩坑记)
2019/09/29 Javascript
原生js canvas实现鼠标跟随效果
2020/08/02 Javascript
vue修改Element的el-table样式的4种方法
2020/09/17 Javascript
vue 获取url里参数的两种方法小结
2020/11/12 Javascript
Vue实现小购物车功能
2020/12/21 Vue.js
js 执行上下文和作用域的相关总结
2021/02/08 Javascript
[03:46]显微镜下的DOTA2第七期——满血与残血
2014/06/20 DOTA
[00:35]DOTA2上海特级锦标赛 Newbee战队宣传片
2016/03/03 DOTA
教大家玩转Python字符串处理的七种技巧
2017/03/31 Python
详解 Python中LEGB和闭包及装饰器
2017/08/03 Python
wx.CheckBox创建复选框控件并响应鼠标点击事件
2018/04/25 Python
Python实现的多项式拟合功能示例【基于matplotlib】
2018/05/15 Python
python 根据字典的键值进行排序的方法
2019/07/24 Python
python 画3维轨迹图并进行比较的实例
2019/12/06 Python
python使用列表的最佳方案
2020/08/12 Python
全网最细 Python 格式化输出用法讲解(推荐)
2021/01/18 Python
复古斯堪的纳维亚儿童服装:Baby go Retro
2017/09/09 全球购物
数据库方面面试题
2012/04/22 面试题
故意伤害人身损害赔偿协议书
2014/11/19 职场文书
使用 JavaScript 制作页面效果
2021/04/21 Javascript
Java后台生成图片的完整步骤
2021/08/04 Java/Android