双向RNN:bidirectional_dynamic_rnn()函数的使用详解


Posted in Python onJanuary 20, 2020

双向RNN:bidirectional_dynamic_rnn()函数的使用详解

先说下为什么要使用到双向RNN,在读一篇文章的时候,上文提到的信息十分的重要,但这些信息是不足以捕捉文章信息的,下文隐含的信息同样会对该时刻的语义产生影响。

举一个不太恰当的例子,某次工作会议上,领导进行“简洁地”总结,他会在第一句告诉你:“下面,为了节约时间,我简单地说两点…”,(…此处略去五百字…),“首先,….”,(…此处略去一万字…),“碍于时间的关系,我要加快速度了,下面我简要说下第二点…”(…此处再次略去五千字…)“好的,我想说的大概就是这些”(…此处又略去了二百字…),“谢谢大家!”如果将这篇发言交给一个单层的RNN网络去学习,因为“首先”和“第二点”中间隔得实在太久,等到开始学习“第二点”时,网络已经忘记了“简单地说两点”这个重要的信息,最终的结果就只剩下在风中凌乱了。。。于是我们决定加一个反向的网络,从后开始往前听,对于这层网络,他首先听到的就是“第二点”,然后是“首先”,最后,他对比了一下果然仅仅是“简要地两点”,在于前向的网络进行结合,就深入学习了领导的指导精神。

双向RNN:bidirectional_dynamic_rnn()函数的使用详解

上图是一个双向LSTM的结构图,对于最后输出的每个隐藏状态双向RNN:bidirectional_dynamic_rnn()函数的使用详解 都是前向网络和后向网络的元组,即双向RNN:bidirectional_dynamic_rnn()函数的使用详解 其中每一个双向RNN:bidirectional_dynamic_rnn()函数的使用详解 或者双向RNN:bidirectional_dynamic_rnn()函数的使用详解 又是一个由隐藏状态和细胞状态组成的元组(或者是concat)。同样最终的output也是需要将前向和后向的输出concat起来的,这样就保证了在最终时刻,无论是输出还是隐藏状态都是有考虑了上文和下文信息的。

下面就来看下tensorflow中已经集成的 tf.nn.bidirectional_dynamic_rnn() 函数。似乎双向的暂时只有这一个动态的RNN方法,不过想想也能理解,这种结构暂时也只会在encoder端出现,无论你的输入是pad到了定长或者是不定长的,动态RNN都是可以处理的。

具体的定义如下:

tf.nn.bidirectional_dynamic_rnn(
 cell_fw,
 cell_bw,
 inputs,
 sequence_length=None,
 initial_state_fw=None,
 initial_state_bw=None,
 dtype=None,
 parallel_iterations=None,
 swap_memory=False,
 time_major=False,
 scope=None
)

仔细看这个方法似乎和dynamic_rnn()没有太大区别,无非是多加了一个bw的部分,事实上也的确如此。先看下前向传播的部分:

with vs.variable_scope(scope or "bidirectional_rnn"):
 # Forward direction
 with vs.variable_scope("fw") as fw_scope:
  output_fw, output_state_fw = dynamic_rnn(
    cell=cell_fw, inputs=inputs, 
    sequence_length=sequence_length,
    initial_state=initial_state_fw, 
    dtype=dtype,
    parallel_iterations=parallel_iterations, 
    swap_memory=swap_memory,
    scope=fw_scope)

完全就是一个dynamic_rnn(),至于你选择LSTM或者GRU,只是cell的定义不同罢了。而双向RNN的核心就在于反向的bw部分。刚才说过,反向部分就是从后往前读,而这个翻转的部分,就要用到一个reverse_sequence()的方法,来看一下这一部分:

with vs.variable_scope("bw") as bw_scope:
 # ———————————— 此处是重点 ———————————— 
 inputs_reverse = _reverse(
   inputs, seq_lengths=sequence_length,
   seq_dim=time_dim, batch_dim=batch_dim)
 # ————————————————————————————————————
 tmp, output_state_bw = dynamic_rnn(
   cell=cell_bw, 
   inputs=inputs_reverse, 
   sequence_length=sequence_length,
   initial_state=initial_state_bw, 
   dtype=dtype,
   parallel_iterations=parallel_iterations,
   swap_memory=swap_memory,
   time_major=time_major, 
   scope=bw_scope)

我们可以看到,这里的输入不再是inputs,而是一个inputs_reverse,根据time_major的取值,time_dim和batch_dim组合的 {0,1} 取值正好相反,也就对应了时间维和批量维的词序关系。

而最终的输出:

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

这里还有最后的一个小问题,output_states是一个元组的元组,我个人的处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态。

以上这篇双向RNN:bidirectional_dynamic_rnn()函数的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现简单socket通信的方法
Apr 19 Python
关于python pyqt5安装失败问题的解决方法
Aug 08 Python
flask入门之文件上传与邮件发送示例
Jul 18 Python
python中协程实现TCP连接的实例分析
Oct 14 Python
详解基于python的多张不同宽高图片拼接成大图
Sep 26 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
Python用类实现扑克牌发牌的示例代码
Jun 01 Python
python自定义函数def的应用详解
Jun 03 Python
使用keras时input_shape的维度表示问题说明
Jun 29 Python
python实现xlwt xlrd 指定条件给excel行添加颜色
Jul 14 Python
Python3爬虫里关于代理的设置总结
Jul 30 Python
基于PyTorch实现一个简单的CNN图像分类器
May 29 Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
You might like
php动态函数调用方法
2015/05/21 PHP
分享五个PHP7性能优化提升技巧
2015/12/07 PHP
thinkphp3.x中cookie方法的用法分析
2016/05/19 PHP
简单解决微信文章图片防盗链问题
2016/12/17 PHP
浅谈PHP接入(第三方登录)QQ登录 OAuth2.0 过程中遇到的坑
2017/10/13 PHP
Javascript 各浏览器的 Javascript 效率对比
2008/01/23 Javascript
javascript自执行函数之伪命名空间封装法
2010/12/25 Javascript
JS中prototype关键字的功能介绍及使用示例
2013/07/21 Javascript
通过一段代码简单说js中的this的使用
2013/07/23 Javascript
利用jquery.qrcode在页面上生成二维码且支持中文
2014/02/12 Javascript
鼠标经过子元素触发mouseout,mouseover事件的解决方案
2015/07/26 Javascript
JavaScript常用函数工具集:lao-utils
2016/03/01 Javascript
EasyUI中在表单提交之前进行验证
2016/07/19 Javascript
Angular Module声明和获取重载实例代码
2016/09/14 Javascript
Bootstrap modal使用及点击外部不消失的解决方法
2016/12/13 Javascript
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
详解vue服务端渲染(SSR)初探
2017/06/19 Javascript
微信小程序解析富文本过程详解
2019/07/13 Javascript
js实现旋转的星空效果
2019/11/01 Javascript
[01:01:51]EG vs VG Supermajor小组赛B组 BO3 第二场 6.2
2018/06/03 DOTA
python开发环境PyScripter中文乱码问题解决方案
2016/09/11 Python
Python使用字典的嵌套功能详解
2019/02/27 Python
Python实现将字符串的首字母变为大写,其余都变为小写的方法
2019/06/11 Python
PyQt编程之如何在屏幕中央显示窗体的实例
2019/06/18 Python
Pandas时间序列:时期(period)及其算术运算详解
2020/02/25 Python
Python常用类型转换实现代码实例
2020/07/28 Python
python基于selenium爬取斗鱼弹幕
2021/02/20 Python
雪花秀美国官方网站:韩国著名草本护肤化妆品品牌
2016/10/19 全球购物
什么是属性访问器
2015/10/26 面试题
终止或解除劳动合同及劳动关系的证明书
2014/10/06 职场文书
个人简历求职信范文
2015/03/20 职场文书
社区文明倡议书
2015/04/28 职场文书
2015年全国助残日活动方案
2015/05/04 职场文书
党支部评议意见
2015/06/02 职场文书
雷锋观后感
2015/06/10 职场文书
浅谈redis的过期时间设置和过期删除机制
2022/03/18 MySQL