浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现根据ip地址反向查找主机名称的方法
Apr 29 Python
分析并输出Python代码依赖的库的实现代码
Aug 09 Python
python实现简单购物商城
May 21 Python
Python基于numpy灵活定义神经网络结构的方法
Aug 19 Python
tensorflow识别自己手写数字
Mar 14 Python
python装饰器深入学习
Apr 06 Python
Python使用Dijkstra算法实现求解图中最短路径距离问题详解
May 16 Python
Django中反向生成models.py的实例讲解
May 30 Python
python存储16bit和32bit图像的实例
Dec 05 Python
python实现中文文本分句的例子
Jul 15 Python
k-means 聚类算法与Python实现代码
Jun 01 Python
详解pandas赋值失败问题解决
Nov 29 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
在Windows下编译适用于PHP 5.2.12及5.2.13的eAccelerator.dll(附下载)
2010/05/04 PHP
关于PHP session 存储方式的详细介绍
2013/06/25 PHP
PHP连接access数据库
2015/03/27 PHP
PHP实现远程下载文件到本地
2015/05/17 PHP
cookie的复制与使用记住用户名实现代码
2013/11/04 Javascript
jquery自定义滚动条插件示例分享
2014/02/21 Javascript
js判断浏览器类型为ie6时不执行
2014/06/15 Javascript
基于jQuery Tipso插件实现消息提示框特效
2016/03/16 Javascript
前端js文件合并的三种方式推荐
2016/05/19 Javascript
js H5 canvas投篮小游戏
2016/08/18 Javascript
node打造微信个人号机器人的方法示例
2018/04/26 Javascript
基于jQuery实现的设置文本区域的光标位置
2018/06/15 jQuery
layui-laydate时间日历控件使用方法详解
2018/11/15 Javascript
vue+php实现的微博留言功能示例
2019/03/16 Javascript
非常漂亮的js烟花效果
2020/03/10 Javascript
TypeScript魔法堂之枚举的超实用手册
2020/10/29 Javascript
在vue中通过render函数给子组件设置ref操作
2020/11/17 Vue.js
[02:22]《新闻直播间》2017年08月14日
2017/08/15 DOTA
简单的抓取淘宝图片的Python爬虫
2014/12/25 Python
Python编程中的文件读写及相关的文件对象方法讲解
2016/01/19 Python
python 读写txt文件 json文件的实现方法
2016/10/22 Python
python使用正则表达式替换匹配成功的组并输出替换的次数
2017/11/22 Python
Python使用min、max函数查找二维数据矩阵中最小、最大值的方法
2018/05/15 Python
解决pandas 作图无法显示中文的问题
2018/05/24 Python
连接pandas以及数组转pandas的方法
2019/06/28 Python
Pytorch中的variable, tensor与numpy相互转化的方法
2019/10/10 Python
windows python3安装Jupyter Notebooks教程
2020/04/13 Python
Python可以实现栈的结构吗
2020/05/27 Python
Selenium执行完毕未关闭chromedriver/geckodriver进程的解决办法(java版+python版)
2020/12/07 Python
TensorFlow2.0使用keras训练模型的实现
2021/02/20 Python
法国高保真音响和家庭影院商店:Son Video
2019/04/26 全球购物
考试不及格检讨书
2014/01/09 职场文书
服务员岗位职责
2014/01/29 职场文书
2015年妇女工作总结
2015/05/14 职场文书
欠款起诉书范文
2015/05/19 职场文书
Python通过m3u8文件下载合并ts视频的操作
2021/04/16 Python