关于tf.nn.dynamic_rnn返回值详解


Posted in Python onJanuary 20, 2020

函数原型

tf.nn.dynamic_rnn(
  cell,
  inputs,
  sequence_length=None,
  initial_state=None,
  dtype=None,
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  scope=None
)

实例讲解:

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
 
seq_length = tf.placeholder(tf.int32, [None])
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32,
                  sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs_val.shape, "states_val.shape:", states_val.shape)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: (4, 2, 5) states_val.shape: (4, 5)
outputs_val: 
[[[ 0.53073734 -0.61281306 -0.5437517  0.7320347 -0.6109526 ]
 [ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]]
 
 [[ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.     0.     0.     0.     0.    ]]
 
 [[ 0.99992317 0.96815354 -0.985101  0.9995968 -0.9999936 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]]
 
 [[ 0.99999255 0.9998929  0.26732785 0.36024097 -0.99991137]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]] 
states_val:
 [[ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]
 [ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]

首先输入X是一个 [batch_size,step,input_size] = [4,2,3] 的tensor,注意我们这里调用的是BasicRNNCell,只有一层循环网络,outputs是最后一层每个step的输出,它的结构是[batch_size,step,n_neurons] = [4,2,5],states是每一层的最后那个step的输出,由于本例中,我们的循环网络只有一个隐藏层,所以它就代表这一层的最后那个step的输出,因此它和step的大小是没有关系的,我们的X有4个样本组成,输出神经元大小n_neurons是5,因此states的结构就是[batch_size,n_neurons] = [4,5],最后我们观察数据,states的每条数据正好就是outputs的最后一个step的输出。

下面我们继续讲解多个隐藏层的情况,这里是三个隐藏层,注意我们这里仍然是调用BasicRNNCell

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
n_layers = 3
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])
 
layers = [tf.contrib.rnn.BasicRNNCell(num_units=n_neurons,
                   activation=tf.nn.relu)
     for layer in range(n_layers)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32, sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
 
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs, "states_val.shape:", states)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape: 
(<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>)
 
outputs_val:
 [[[0.     0.     0.     0.     0.    ]
 [0.     0.18740742 0.     0.2997518 0.    ]]
 
 [[0.     0.07222144 0.     0.11551574 0.    ]
 [0.     0.     0.     0.     0.    ]]
 
 [[0.     0.13463384 0.     0.21534224 0.    ]
 [0.03702604 0.18443246 0.     0.34539366 0.    ]]
 
 [[0.     0.54511094 0.     0.8718864 0.    ]
 [0.5382122 0.     0.04396425 0.4040263 0.    ]]] 
 
states_val:
 (array([[0.    , 0.83723307, 0.    , 0.    , 2.8518028 ],
    [0.    , 0.1996038 , 0.    , 0.    , 1.5456247 ],
    [0.    , 1.1372368 , 0.    , 0.    , 0.832613 ],
    [0.    , 0.7904129 , 2.4675028 , 0.    , 0.36980057]],
   dtype=float32), 
 array([[0.6524607 , 0.    , 0.    , 0.    , 0.    ],
    [0.25143963, 0.    , 0.    , 0.    , 0.    ],
    [0.5010576 , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.3166597 , 0.4545995 , 0.    , 0.    ]],
   dtype=float32), 
 array([[0.    , 0.18740742, 0.    , 0.2997518 , 0.    ],
    [0.    , 0.07222144, 0.    , 0.11551574, 0.    ],
    [0.03702604, 0.18443246, 0.    , 0.34539366, 0.    ],
    [0.5382122 , 0.    , 0.04396425, 0.4040263 , 0.    ]],
   dtype=float32))

我们说过,outputs是最后一层的输出,即 [batch_size,step,n_neurons] = [4,2,5]

states是每一层的最后一个step的输出,即三个结构为 [batch_size,n_neurons] = [4,5] 的tensor

继续观察数据,states中的最后一个array,正好是outputs的最后那个step的输出

下面我们继续讲当由BasicLSTMCell构造单元工厂的时候,只讲多层的情况,我们只需要将上面的BasicRNNCell替换成BasicLSTMCell就行了,打印信息如下:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape:
(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(?, 5) dtype=float32>))
 
outputs_val: 
[[[1.2949290e-04 0.0000000e+00 2.7623639e-04 0.0000000e+00 0.0000000e+00]
 [9.4675866e-05 0.0000000e+00 2.0214770e-04 0.0000000e+00 0.0000000e+00]]
 
 [[4.3100454e-06 4.2123037e-07 1.4312843e-06 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]] 
 
states_val: 
(LSTMStateTuple(
c=array([[0.    , 0.    , 0.04676079, 0.04284539, 0.    ],
    [0.    , 0.    , 0.0115245 , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32), 
h=array([[0.    , 0.    , 0.00035096, 0.04284406, 0.    ],
    [0.    , 0.    , 0.00142574, 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32)), 
LSTMStateTuple(
c=array([[0.0000000e+00, 1.0477135e-02, 4.9871090e-03, 8.2785974e-04,
    0.0000000e+00],
    [0.0000000e+00, 2.3306280e-04, 0.0000000e+00, 9.9445322e-05,
    5.9535629e-05],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[0.00000000e+00, 5.23016974e-03, 2.47756205e-03, 4.11730434e-04,
    0.00000000e+00],
    [0.00000000e+00, 1.16522635e-04, 0.00000000e+00, 4.97301044e-05,
    2.97713632e-05],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00]], dtype=float32)), 
LSTMStateTuple(
c=array([[1.8937115e-04, 0.0000000e+00, 4.0442235e-04, 0.0000000e+00,
    0.0000000e+00],
    [8.6200516e-06, 8.4243663e-07, 2.8625946e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[9.4675866e-05, 0.0000000e+00, 2.0214770e-04, 0.0000000e+00,
    0.0000000e+00],
    [4.3100454e-06, 4.2123037e-07, 1.4312843e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32)))

我们先看看LSTM单元的结构

关于tf.nn.dynamic_rnn返回值详解

如果您不查看框内的内容,LSTM单元看起来与常规单元格完全相同,除了它的状态分为两个向量:h(t)和c(t)。你可以将h(t)视为短期状态,将c(t)视为长期状态。

因此我们的states包含三个LSTMStateTuple,每一个表示每一层的最后一个step的输出,这个输出有两个信息,一个是h表示短期记忆信息,一个是c表示长期记忆信息。维度都是[batch_size,n_neurons] = [4,5],states的最后一个LSTMStateTuple中的h就是outputs的最后一个step的输出

以上这篇关于tf.nn.dynamic_rnn返回值详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的日志模块logging
Jun 19 Python
Python基于回溯法子集树模板解决数字组合问题实例
Sep 02 Python
正确理解Python中if __name__ == '__main__'
Jan 24 Python
Django REST framework 分页的实现代码
Jun 19 Python
详解10个可以快速用Python进行数据分析的小技巧
Jun 24 Python
Python搭建代理IP池实现获取IP的方法
Oct 27 Python
Python实现随机取一个矩阵数组的某几行
Nov 26 Python
python将数组n等分的实例
Dec 02 Python
Python求两个字符串最长公共子序列代码实例
Mar 05 Python
如何在Python 游戏中模拟引力
Mar 27 Python
python 服务器运行代码报错ModuleNotFoundError的解决办法
Sep 16 Python
python如何为list实现find方法
May 30 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
You might like
php一行代码获取文件后缀名实例分析
2014/11/12 PHP
PHP操作FTP类 (上传、下载、移动、创建等)
2016/03/31 PHP
利用CSS、JavaScript及Ajax实现高效的图片预加载
2013/10/16 Javascript
Knockout text绑定DOM的使用方法
2013/11/15 Javascript
浅析Node在构建超媒体API中的作用
2014/07/30 Javascript
jQuery中的jQuery()方法用法分析
2014/12/27 Javascript
JavaScript实现倒计时代码段Item1(非常实用)
2015/11/03 Javascript
JavaScript的字符串方法汇总
2016/07/31 Javascript
Javascript从数组中随机取出不同元素的两种方法
2016/09/22 Javascript
微信小程序 开发指南详解
2016/09/27 Javascript
js实现百度地图定位于地址逆解析,显示自己当前的地理位置
2016/12/08 Javascript
Javascript基础回顾之(二) js作用域
2017/01/31 Javascript
深入理解Javascript箭头函数中的this
2017/02/13 Javascript
微信小程序 本地数据读取实例
2017/04/27 Javascript
vue.js删除动态绑定的radio的指定项
2017/06/02 Javascript
JS中的三个循环小结
2017/06/20 Javascript
js指定步长实现单方向匀速运动
2017/07/17 Javascript
分享5个小技巧让你写出更好的 JavaScript 条件语句
2018/10/20 Javascript
element vue Array数组和Map对象的添加与删除操作
2018/11/14 Javascript
Python collections模块实例讲解
2014/04/07 Python
Python中文字符串截取问题
2015/06/15 Python
Python实现读取字符串按列分配后按行输出示例
2018/04/17 Python
详解Django+uwsgi+Nginx上线最佳实战
2019/03/14 Python
ipad上运行python的方法步骤
2019/10/12 Python
python实现经纬度采样的示例代码
2020/12/10 Python
CSS3中设置3D变形的transform-style属性详解
2016/05/23 HTML / CSS
html5使用canvas画三角形
2014/12/15 HTML / CSS
HTML5+lufylegend实现游戏中的卷轴
2016/02/29 HTML / CSS
美国南部最大的家族百货公司:Belk
2017/01/30 全球购物
欧洲最大的球衣网上商店:Kitbag
2017/11/11 全球购物
英国可持续奢侈品包包品牌:Elvis & Kresse
2018/08/05 全球购物
Beauty Expert美国/加拿大:购买奢侈美容产品
2018/12/05 全球购物
澳洲最大的时尚奢侈品电商平台:Cettire
2020/06/15 全球购物
教育学专业实习生的自我鉴定
2013/11/26 职场文书
2015年个人招商工作总结
2015/04/25 职场文书
【海涛七七解说】DCG第二周:DK VS 天禄
2022/04/01 DOTA