浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python构建自定义回调函数详解
Jun 20 Python
利用numpy和pandas处理csv文件中的时间方法
Apr 19 Python
用pycharm开发django项目示例代码
Oct 24 Python
Python生成MD5值的两种方法实例分析
Apr 26 Python
django表单的Widgets使用详解
Jul 22 Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 Python
Pytorch中index_select() 函数的实现理解
Nov 19 Python
python爬取本站电子书信息并入库的实现代码
Jan 20 Python
解决pyCharm中 module 调用失败的问题
Feb 12 Python
Python 开发工具PyCharm安装教程图文详解(新手必看)
Feb 28 Python
Python第三方包之DingDingBot钉钉机器人
Apr 09 Python
在django中实现choices字段获取对应字段值
Jul 12 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
ThinkPHP采用原生query实现关联查询left join实例
2014/12/02 PHP
PHP实现长文章分页实例代码(附源码)
2016/02/03 PHP
js中几种去掉字串左右空格的方法
2006/12/25 Javascript
javascript设计模式 封装和信息隐藏(上)
2012/07/24 Javascript
JS添加删除一组文本框并对输入信息加以验证判断其正确性
2013/04/11 Javascript
Javascript实现重力弹跳拖拽运动效果示例
2013/06/28 Javascript
jQuery大于号(>)选择器的作用解释
2015/01/13 Javascript
4种JavaScript实现简单tab选项卡切换的方法
2016/01/06 Javascript
浅谈时钟的生成(js手写简洁代码)
2016/08/20 Javascript
Angular实现跨域(搜索框的下拉列表)
2017/02/16 Javascript
JS区分Object与Aarry的六种方法总结
2017/02/27 Javascript
AngularJS2中一种button切换效果的实现方法(二)
2017/03/27 Javascript
angularjs之$timeout指令详解
2017/06/13 Javascript
js限制input只能输入有效的数字(第一个不能是小数点)
2018/09/28 Javascript
node.js的http.createServer过程深入解析
2019/06/06 Javascript
[04:56]经典回顾:前Ehome 与 前LGD
2015/02/26 DOTA
[01:05:29]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Aster BO3 第二场 1月24日
2021/03/11 DOTA
Python数据集切分实例
2018/12/08 Python
python导包的几种方法(自定义包的生成以及导入详解)
2019/07/15 Python
Python制作简易版小工具之计算天数的实现思路
2020/02/13 Python
jupyter notebook 增加kernel教程
2020/04/10 Python
Python生成器next方法和send方法区别详解
2020/05/30 Python
JAVA SWT事件四种写法实例解析
2020/06/05 Python
pycharm 实现本地写代码,服务器运行的操作
2020/06/08 Python
五分钟带你搞懂python 迭代器与生成器
2020/08/30 Python
python实现逻辑回归的示例
2020/10/09 Python
4款Python 类型检查工具,你选择哪个呢?
2020/10/30 Python
pycharm 配置svn的图文教程(手把手教你)
2021/01/15 Python
Perricone MD裴礼康美国官网:抗衰老护肤品
2016/09/26 全球购物
美国嘻哈首饰购物网站:Hip Hop Bling
2016/12/30 全球购物
英国50岁以上人群的交友网站:Ourtime
2018/03/28 全球购物
理财投资建议书
2014/03/12 职场文书
幼儿园2014年度工作总结
2014/11/10 职场文书
初中作文评语
2014/12/25 职场文书
2016圣诞节贺卡寄语
2015/12/07 职场文书
springmvc直接不经过controller访问WEB-INF中的页面问题
2022/02/24 Java/Android