浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现树的先序、中序、后序排序算法示例
Jun 23 Python
Django自定义manage命令实例代码
Feb 11 Python
Python堆排序原理与实现方法详解
May 11 Python
Python使用min、max函数查找二维数据矩阵中最小、最大值的方法
May 15 Python
10 行 Python 代码教你自动发送短信(不想回复工作邮件妙招)
Oct 11 Python
Python找出微信上删除你好友的人脚本写法
Nov 01 Python
Python文件路径名的操作方法
Oct 30 Python
Python3 读取Word文件方式
Feb 13 Python
python设置代理和添加镜像源的方法
Feb 14 Python
django rest framework serializer返回时间自动格式化方法
Mar 31 Python
Python3.8.2安装包及安装教程图文详解(附安装包)
Nov 28 Python
Python Django获取URL中的数据详解
Nov 01 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
PHP中使用数组实现堆栈数据结构的代码
2012/02/05 PHP
php编写的简单页面跳转功能实现代码
2013/11/27 PHP
PHP7+Nginx的配置与安装教程详解
2016/05/10 PHP
Laravel5中防止XSS跨站攻击的方法
2016/10/10 PHP
Yii2学习笔记之汉化yii设置表单的描述(属性标签attributeLabels)
2017/02/07 PHP
从javascript语言本身谈项目实战
2006/12/27 Javascript
让firefox支持IE的一些方法的javascript扩展函数代码
2010/01/02 Javascript
JS关键字变色实现思路及代码
2013/02/21 Javascript
javascript数组快速打乱重排的方法
2014/01/02 Javascript
node.js使用require()函数加载模块
2014/11/26 Javascript
判断是否存在子节点的实现代码
2016/05/18 Javascript
jQuery封装的屏幕居中提示信息代码
2016/06/08 Javascript
zepto与jquery的区别及zepto的不同使用8条小结
2016/07/28 Javascript
用jQuery.ajaxSetup实现对请求和响应数据的过滤
2016/12/20 Javascript
基于Vue实现tab栏切换内容不断实时刷新数据功能
2017/04/13 Javascript
详解React-Native全球化多语言切换工具库react-native-i18n
2017/11/03 Javascript
Vue resource三种请求格式和万能测试地址
2018/09/26 Javascript
NUXT SSR初级入门笔记(小结)
2019/12/16 Javascript
vue element table中自定义一些input的验证操作
2020/07/18 Javascript
Python批量合并有合并单元格的Excel文件详解
2018/04/05 Python
Python爬虫实战之12306抢票开源
2019/01/24 Python
python调用HEG工具批量处理MODIS数据的方法及注意事项
2020/02/18 Python
Django 权限管理(permissions)与用户组(group)详解
2020/11/30 Python
Python排序函数的使用方法详解
2020/12/11 Python
python爬取2021猫眼票房字体加密实例
2021/02/19 Python
彻底弄明白CSS3的Media Queries(跨平台设计)
2010/07/27 HTML / CSS
美国现代家具购物网站:LexMod
2019/01/09 全球购物
用C或者C++语言实现SOCKET通信
2015/02/24 面试题
实习医生自我评价
2013/09/22 职场文书
个人生活学习自我评价范文
2013/11/26 职场文书
校园公益广告语
2014/03/13 职场文书
教师新年寄语
2014/04/03 职场文书
文明美德伴我成长演讲稿
2014/05/12 职场文书
会计学毕业生求职信
2014/06/25 职场文书
员工工作心得体会
2019/05/07 职场文书
2019年工作总结范文
2019/05/21 职场文书