浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之字典,你还记得吗?
Sep 20 Python
使用基于Python的Tornado框架的HTTP客户端的教程
Apr 24 Python
编写Python脚本把sqlAlchemy对象转换成dict的教程
May 29 Python
PyQt5实现拖放功能
Apr 25 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
Oct 12 Python
flask 框架操作MySQL数据库简单示例
Feb 02 Python
Tensorflow中tf.ConfigProto()的用法详解
Feb 06 Python
使用python从三个角度解决josephus问题的方法
Mar 27 Python
python安装后的目录在哪里
Jun 21 Python
tensorflow图像裁剪进行数据增强操作
Jun 30 Python
通俗易懂了解Python装饰器原理
Sep 17 Python
python脚本框架webpy的url映射详解
Nov 20 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
比较简单实用的PHP无限分类源码分享(思路不错)
2011/10/13 PHP
使用JSON实现数据的跨域传输的php代码
2011/12/20 PHP
PHP中去掉字符串首尾空格的方法
2012/05/19 PHP
PHP 获取远程文件大小的3种解决方法
2013/07/11 PHP
php内存缓存实现方法
2015/01/24 PHP
php+flash+jQuery多图片上传源码分享
2020/07/27 PHP
php 删除一维数组中某一个值元素的操作方法
2018/02/01 PHP
php + WebUploader实现图片批量上传功能
2019/05/06 PHP
发现的以前不知道的函数
2006/09/19 Javascript
js focus不起作用的解决方法(主要是因为dom元素是否加载完成)
2010/11/05 Javascript
window.open不被拦截的实现代码
2012/08/22 Javascript
再JavaScript的jQuery库中编写动画效果的指南
2015/08/13 Javascript
谷歌Chrome浏览器扩展程序开发小记
2016/01/06 Javascript
客户端验证用户名和密码的方法详解
2016/06/16 Javascript
JS设计模式之惰性模式(二)
2017/09/29 Javascript
微信小程序用户信息encryptedData详解
2018/08/24 Javascript
vue.js 实现点击按钮动态添加li的方法
2018/09/07 Javascript
vue计算属性computed的使用方法示例
2019/03/13 Javascript
Node.js API详解之 string_decoder用法实例分析
2020/04/29 Javascript
解决vue-router路由拦截造成死循环问题
2020/08/05 Javascript
Python实现二叉树结构与进行二叉树遍历的方法详解
2016/05/24 Python
详解python开发环境搭建
2016/12/16 Python
python分别打包出32位和64位应用程序
2020/02/18 Python
如何使用css3实现一个类在线直播的队列动画的示例代码
2020/06/17 HTML / CSS
普通PHP程序员笔试题
2016/01/01 面试题
Shell脚本如何向终端输出信息
2014/04/25 面试题
儿科主治医生个人求职信
2013/09/23 职场文书
应届毕业生求职信范文
2014/07/07 职场文书
党员干部观看《周恩来四个昼夜》思想汇报
2014/09/10 职场文书
2014幼儿园教师个人工作总结
2014/11/08 职场文书
2014年后勤工作总结
2014/11/18 职场文书
2015年房产销售工作总结范文
2015/05/22 职场文书
导游词之茶卡盐湖
2019/11/26 职场文书
基于python的matplotlib制作双Y轴图
2021/04/20 Python
十大动画制作软件,Adobe产品上榜两款,第一是行业标准软件
2022/03/18 杂记
JS前端使用canvas实现物体的点选示例
2022/08/05 Javascript