浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python制作微信跳一跳辅助
Jan 31 Python
解决pycharm运行出错,代码正确结果不显示的问题
Nov 30 Python
kafka-python批量发送数据的实例
Dec 27 Python
python实现桌面壁纸切换功能
Jan 21 Python
使用python爬取微博数据打造一颗“心”
Jun 28 Python
Python3内置模块random随机方法小结
Jul 13 Python
Python re 模块findall() 函数返回值展现方式解析
Aug 09 Python
python图形绘制奥运五环实例讲解
Sep 14 Python
python如何实现不用装饰器实现登陆器小程序
Dec 14 Python
Python 解决OPEN读文件报错 ,路径以及r的问题
Dec 19 Python
Python 读取有公式cell的结果内容实例方法
Feb 17 Python
Python 图片处理库exifread详解
Feb 25 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
php生成静态文件的多种方法分享
2012/07/17 PHP
PHPStorm+XDebug进行调试图文教程
2016/06/13 PHP
javascript数组组合成字符串的脚本
2021/01/06 Javascript
cnblogs TagCloud基于jquery的实现代码
2010/06/11 Javascript
基于JQuery的6个Tab选项卡插件
2010/09/03 Javascript
js不能跳转到上一页面的问题解决方法
2013/03/01 Javascript
下拉列表select 由左边框移动到右边示例
2013/12/04 Javascript
jQuery中DOM树操作之复制元素的方法
2015/01/23 Javascript
jQuery实现购物车数字加减效果
2015/03/14 Javascript
jquery实现的3D旋转木马特效代码分享
2015/08/25 Javascript
javascript基础知识讲解
2017/01/11 Javascript
原生js实现新闻列表展开/收起全文功能
2017/01/20 Javascript
Angular2平滑升级到Angular4的步骤详解
2017/03/29 Javascript
使用jQuery实现购物车结算功能
2017/08/15 jQuery
AngularJS中table表格基本操作示例
2017/10/10 Javascript
详解extract-text-webpack-plugin 的使用及安装
2018/06/12 Javascript
js实现按钮开关单机下拉菜单效果
2018/11/22 Javascript
react高阶组件添加和删除props
2019/04/26 Javascript
浅谈vue中resetFields()使用注意事项
2020/08/12 Javascript
vue Treeselect下拉树只能选择第N级元素实现代码
2020/08/31 Javascript
python计算文本文件行数的方法
2015/07/06 Python
python爬取网易云音乐评论
2018/11/16 Python
pandas删除行删除列增加行增加列的实现
2019/07/06 Python
Pytorch实现GoogLeNet的方法
2019/08/18 Python
Python 元组操作总结
2019/09/18 Python
Python爬虫之urllib基础用法教程
2019/10/12 Python
利用matplotlib为图片上添加触发事件进行交互
2020/04/23 Python
HTML5中div、article、section的区别及使用介绍
2013/08/14 HTML / CSS
详解HTML5中的拖放事件(Drag 和 drop)
2016/11/14 HTML / CSS
介绍一下#error预处理
2015/09/25 面试题
如何进行Linux分区优化
2013/02/12 面试题
EJB2和EJB3在架构上的不同点
2014/09/29 面试题
客服文员岗位职责
2013/11/29 职场文书
运动会班级口号
2014/06/09 职场文书
2016年安康杯竞赛活动总结
2016/04/05 职场文书
JS ES6异步解决方案
2021/04/29 Javascript