将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例


Posted in Python onJanuary 04, 2020

在神经网络计算过程中,经常会遇到需要将矩阵中的某些元素取出并且单独进行计算的步骤(例如MLE,Attention等操作)。那么在 tensorflow 的 Variable 类型中如何做到这一点呢?

首先假设 Variable 是一个一维数组 A:

import numpy as np

import tensorflow as tf

a = np.array([1, 2, 3, 4, 5, 6, 7, 8])

A = tf.Variable(a)

我们把我们想取出的元素的索引存到 B 中,如果我们只想取出数组 A 中的某一个元素,则 B 的设定为:

b = np.array([3])

B = tf.placeholder(dtype=tf.int32, shape=[1])

由于我们的索引坐标只有一维,所以 shape=1。

取出元素然后组合成tensor C 的操作如下:

C = tf.gather_nd(A, B)

运行:

init = tf.global_variables_initializer()

with tf.Session() as sess:
  init.run()
  feed_dict = {B: b}
  result = sess.run([C], feed_dict=feed_dict)
  print result

得到:

[4]

如果我们想取出一维数组中的多个元素,则需要把每一个想取出的元素索引都单独放一行:

b = np.array([[3], [2], [5], [0]])

B = tf.placeholder(dtype=tf.int32, shape=[4, 1])

此时由于我们想要从一维数组中索引 4 个数,所以 shape=[4, 1]

再次运行得到:

[4 3 6 1]

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

假设 Variable 是一个二维矩阵 A:

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

A = tf.Variable(a)

首先我们先取出 A 中的一个元素,需要给定该元素的行列坐标,存到 B 中:

b = np.array([2,3])

B = tf.placeholder(dtype=tf.int32, shape=[2])

注意由于我们输入的索引坐标变成了二维,所以shape也变为2。

取出元素然后组合成tensor C:

C = tf.gather_nd(A, B)

运行:

init = tf.global_variables_initializer()

with tf.Session() as sess:
  init.run()
  feed_dict = {B: b}
  result = sess.run([C], feed_dict=feed_dict)
  print result

得到:

[12]

同样的,如果我们想取出二维矩阵中的多个元素,则需要把每一个想取出的元素的索引都单独放一行:

b = np.array([[2, 3], [1, 0], [2, 2], [0, 1]])

B = tf.placeholder(dtype=tf.int32, shape=[4, 2])

此时由于我们想要从二维矩阵中索引出 4 个数,所以 shape=[4, 2]

再次运行得到:

[12 5 11 2]

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

推广到 n 维矩阵中:

假设 A 是 Variable 类型的 n 维矩阵,我们想取出矩阵中的 m 个元素,那么首先每个元素的索引坐标要表示成列表的形式:

index = [x1, x2, x3, ..., xn]

其中 xj 代表该元素在 n 维矩阵中第 j 维的位置。

其次每个坐标要单独占索引矩阵的一行:

index_matrix = [[x11, x12, x13, ..., x1n],

               [x21, x22, x23, ..., x2n],

               [x31, x32, x33, ..., x3n],

               .......................................,

               [xm1, xm2, xm3, ..., xmn]]

最后用 tf.gather_nd() 函数替换即可:

result = tf.gather_nd(A, index_matrix)

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

[注] 问题出自:https://stackoverflow.com/questions/44793286/slicing-tensorflow-tensor-with-tensor

以上这篇将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将MongoDB里的ObjectId转换为时间戳的方法
Mar 13 Python
分析Python编程时利用wxPython来支持多线程的方法
Apr 07 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
Python格式化字符串f-string概览(小结)
Jun 18 Python
django 2.2和mysql使用的常见问题
Jul 18 Python
Python 用三行代码提取PDF表格数据
Oct 13 Python
Python如何使用turtle库绘制图形
Feb 26 Python
python实现梯度法 python最速下降法
Mar 24 Python
python实现人机五子棋
Mar 25 Python
pytorch实现查看当前学习率
Jun 24 Python
浅谈anaconda python 版本对应关系
Oct 07 Python
python3 os进行嵌套操作的实例讲解
Nov 19 Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
Jan 04 #Python
对tensorflow中的strides参数使用详解
Jan 04 #Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 #Python
TensorFlow tf.nn.max_pool实现池化操作方式
Jan 04 #Python
TensorFlow tf.nn.conv2d实现卷积的方式
Jan 03 #Python
Python调用钉钉自定义机器人的实现
Jan 03 #Python
pytorch中的上采样以及各种反操作,求逆操作详解
Jan 03 #Python
You might like
PHP如何获取当前主机、域名、网址、路径、端口等参数
2017/06/09 PHP
PHP实现链表的定义与反转功能示例
2018/06/09 PHP
javascript cookie解码函数(兼容ff)
2008/03/17 Javascript
javascript 关闭IE6、IE7
2009/06/01 Javascript
写入cookie的JavaScript代码库 cookieLibrary.js
2009/10/24 Javascript
为javascript添加String.Format方法
2020/08/11 Javascript
jquery内置验证(validate)使用方法示例(表单验证)
2013/12/04 Javascript
JavaScript截取字符串的Slice、Substring、Substr函数详解和比较
2014/03/20 Javascript
js单词形式的运算符
2014/05/06 Javascript
jQuery中clone()方法用法实例
2015/01/16 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
jQuery遮罩层效果实例分析
2016/01/14 Javascript
Angularjs实现带查找筛选功能的select下拉框示例代码
2016/10/04 Javascript
详解nodejs 文本操作模块-fs模块(二)
2016/12/22 NodeJs
详解webpack4升级指南以及从webpack3.x迁移
2018/06/12 Javascript
产制造追溯系统之通过微信小程序实现移动端报表平台
2019/06/03 Javascript
uni-app微信小程序登录并使用vuex存储登录状态的思路详解
2019/11/04 Javascript
node实现mock-plugin中间件的方法
2019/12/25 Javascript
JS面向对象之多选框实现
2020/01/17 Javascript
解决Vue使用bus总线时,第一次路由跳转时数据没成功传递问题
2020/07/28 Javascript
[51:44]2018DOTA2亚洲邀请赛 4.3 突围赛 Optic vs iG 第二场
2018/04/04 DOTA
Python使用matplotlib绘制动画的方法
2015/05/20 Python
使用Python的turtle模块画图的方法
2017/11/15 Python
python 打印直角三角形,等边三角形,菱形,正方形的代码
2017/11/21 Python
浅谈关于Python3中venv虚拟环境
2018/08/01 Python
对Python2与Python3中__bool__方法的差异详解
2018/11/01 Python
Python 获取div标签中的文字实例
2018/12/20 Python
在python带权重的列表中随机取值的方法
2019/01/23 Python
Python快速转换numpy数组中Nan和Inf的方法实例说明
2019/02/21 Python
详解python持久化文件读写
2019/04/06 Python
电子商务网站的创业计划书
2014/01/05 职场文书
小学教师事迹材料
2014/01/13 职场文书
人力管理专业毕业生求职信
2014/02/27 职场文书
入党积极分子自我批评思想汇报
2014/10/10 职场文书
详解nginx.conf 中 root 目录设置问题
2021/04/01 Servers
css实现两栏布局,左侧固定宽,右侧自适应的多种方法
2021/08/07 HTML / CSS