浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取豆瓣电影简介代码分享
Jan 16 Python
Python迭代用法实例教程
Sep 08 Python
python将图片文件转换成base64编码的方法
Mar 14 Python
解析Python中的异常处理
Apr 28 Python
为Python的Tornado框架配置使用Jinja2模板引擎的方法
Jun 30 Python
python脚本监控Tomcat服务器的方法
Jul 06 Python
python之文件读取一行一行的方法
Jul 12 Python
对python读取zip压缩文件里面的csv数据实例详解
Feb 08 Python
python根据txt文本批量创建文件夹
Dec 08 Python
Pandas 重塑(stack)和轴向旋转(pivot)的实现
Jul 22 Python
Anaconda配置pytorch-gpu虚拟环境的图文教程
Apr 16 Python
PYTHON InceptionV3模型的复现详解
May 06 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
php set_time_limit(0) 设置程序执行时间的函数
2010/05/26 PHP
初识Laravel
2014/10/30 PHP
javascript delete 使用示例代码
2010/03/29 Javascript
CheckBoxList多选样式jquery、C#获取选择项
2013/09/06 Javascript
jquery配合css简单实现返回顶部效果
2013/09/30 Javascript
JavaScript对象反射用法实例
2015/04/17 Javascript
JavaScript实现简单的数字倒计时
2015/05/15 Javascript
jQuery插件imgAreaSelect基础讲解
2017/05/26 jQuery
js禁止表单重复提交
2017/08/29 Javascript
Angular7中创建组件/自定义指令/管道的方法实例详解
2019/04/02 Javascript
微信小程序常见页面跳转操作简单示例
2019/05/01 Javascript
react 组件传值的三种方法
2019/06/03 Javascript
layer设置maxWidth及maxHeight解决方案
2019/07/26 Javascript
关于layui flow loading占位图的实现方法
2019/09/21 Javascript
JS函数本身的作用域实例分析
2020/03/16 Javascript
javascript设计模式 ? 抽象工厂模式原理与应用实例分析
2020/04/09 Javascript
简单的连接MySQL与Python的Bottle框架的方法
2015/04/30 Python
python中实现将多个print输出合成一个数组
2018/04/19 Python
Django1.9 加载通过ImageField上传的图片方法
2018/05/25 Python
python 读取文件并替换字段的实例
2018/07/12 Python
python3.6数独问题的解决
2019/01/21 Python
Python处理时间日期坐标轴过程详解
2019/06/25 Python
Python时间序列缺失值的处理方法(日期缺失填充)
2019/08/11 Python
基于python及pytorch中乘法的使用详解
2019/12/27 Python
python调用jenkinsAPI构建jenkins,并传递参数的示例
2020/12/09 Python
浅谈css3中的前缀
2016/07/20 HTML / CSS
在canvas上实现元素图片镜像翻转动画效果的方法
2018/03/20 HTML / CSS
实习教师自我鉴定
2013/12/12 职场文书
工作人员思想汇报
2014/01/09 职场文书
缅怀先烈演讲稿
2014/09/03 职场文书
2014镇党委班子对照检查材料思想汇报
2014/09/23 职场文书
2015年控辍保学工作总结
2015/05/18 职场文书
python基于tkinter制作m3u8视频下载工具
2021/04/24 Python
JVM钩子函数的使用场景详解
2021/08/23 Java/Android
如何通过cmd 连接阿里云服务器
2022/04/18 Servers
详解如何使用Nginx解决跨域问题
2022/05/06 Servers