浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例之wxpython中Frame使用方法
Jun 09 Python
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
Jun 09 Python
最大K个数问题的Python版解法总结
Jun 16 Python
解决Python3中的中文字符编码的问题
Jul 18 Python
使用Python的toolz库开始函数式编程的方法
Nov 15 Python
Python使用Shelve保存对象方法总结
Jan 28 Python
python 接口实现 供第三方调用的例子
Aug 13 Python
python常用排序算法的实现代码
Nov 08 Python
深入了解python列表(LIST)
Jun 08 Python
python 实现控制鼠标键盘
Nov 27 Python
python飞机大战游戏实例讲解
Dec 04 Python
C++和python实现阿姆斯特朗数字查找实例代码
Dec 07 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
php 多个submit提交表单 处理方法
2009/07/07 PHP
PHP手机号码归属地查询代码(API接口/mysql)
2012/09/04 PHP
浅析ThinkPHP中的pathinfo模式和URL重写
2014/01/06 PHP
form表单传递数组数据、php脚本接收的实例
2017/02/09 PHP
PHP判断一个数组是另一个数组子集的方法详解
2017/07/31 PHP
thinkPHP框架实现的短信接口验证码功能示例
2018/06/20 PHP
Bootstrap实现下拉菜单效果
2016/04/29 Javascript
jQuery实现的checkbox级联选择下拉菜单效果示例
2016/12/26 Javascript
node.js入门教程之querystring模块的使用方法
2017/02/27 Javascript
详解使用fetch发送post请求时的参数处理
2017/04/05 Javascript
JS操作时间 - UNIX时间戳的简单介绍(必看篇)
2017/08/16 Javascript
在JS循环中使用async/await的方法
2018/10/12 Javascript
Vue发布订阅模式实现过程图解
2020/04/30 Javascript
Python通过select实现异步IO的方法
2015/06/04 Python
Python实现字典按照value进行排序的方法分析
2017/12/23 Python
python爬虫使用cookie登录详解
2017/12/27 Python
解决python打不开文件(文件不存在)的问题
2019/02/18 Python
Python编写合并字典并实现敏感目录的小脚本
2019/02/26 Python
不到20行代码用Python做一个智能聊天机器人
2019/04/19 Python
解决reload(sys)后print失效的问题
2020/04/25 Python
Python数据可视化实现多种图例代码详解
2020/07/14 Python
Python中对象的比较操作==和is区别详析
2021/02/12 Python
css3简单练习实现遨游浏览器logo的绘制
2013/01/30 HTML / CSS
行政部主管岗位职责
2013/12/28 职场文书
运动会通讯稿200字
2014/02/16 职场文书
教师对学生的评语
2014/04/28 职场文书
2014年会策划方案
2014/05/11 职场文书
物理学专业自荐信
2014/06/11 职场文书
创先争优活动党员公开承诺书
2014/08/29 职场文书
科级干部群众路线教育实践活动对照检查材料思想汇报
2014/09/20 职场文书
2015年大学生社会实践评语
2015/03/26 职场文书
2015民办小学年度工作总结
2015/05/26 职场文书
2016年暑假家长对孩子评语
2015/12/01 职场文书
2019军训心得体会
2019/06/27 职场文书
2019入党申请书范文3篇
2019/08/21 职场文书
灵能百分百第三季什么时候来?
2022/03/15 日漫