将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例


Posted in Python onJanuary 04, 2020

在神经网络计算过程中,经常会遇到需要将矩阵中的某些元素取出并且单独进行计算的步骤(例如MLE,Attention等操作)。那么在 tensorflow 的 Variable 类型中如何做到这一点呢?

首先假设 Variable 是一个一维数组 A:

import numpy as np

import tensorflow as tf

a = np.array([1, 2, 3, 4, 5, 6, 7, 8])

A = tf.Variable(a)

我们把我们想取出的元素的索引存到 B 中,如果我们只想取出数组 A 中的某一个元素,则 B 的设定为:

b = np.array([3])

B = tf.placeholder(dtype=tf.int32, shape=[1])

由于我们的索引坐标只有一维,所以 shape=1。

取出元素然后组合成tensor C 的操作如下:

C = tf.gather_nd(A, B)

运行:

init = tf.global_variables_initializer()

with tf.Session() as sess:
  init.run()
  feed_dict = {B: b}
  result = sess.run([C], feed_dict=feed_dict)
  print result

得到:

[4]

如果我们想取出一维数组中的多个元素,则需要把每一个想取出的元素索引都单独放一行:

b = np.array([[3], [2], [5], [0]])

B = tf.placeholder(dtype=tf.int32, shape=[4, 1])

此时由于我们想要从一维数组中索引 4 个数,所以 shape=[4, 1]

再次运行得到:

[4 3 6 1]

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

假设 Variable 是一个二维矩阵 A:

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

A = tf.Variable(a)

首先我们先取出 A 中的一个元素,需要给定该元素的行列坐标,存到 B 中:

b = np.array([2,3])

B = tf.placeholder(dtype=tf.int32, shape=[2])

注意由于我们输入的索引坐标变成了二维,所以shape也变为2。

取出元素然后组合成tensor C:

C = tf.gather_nd(A, B)

运行:

init = tf.global_variables_initializer()

with tf.Session() as sess:
  init.run()
  feed_dict = {B: b}
  result = sess.run([C], feed_dict=feed_dict)
  print result

得到:

[12]

同样的,如果我们想取出二维矩阵中的多个元素,则需要把每一个想取出的元素的索引都单独放一行:

b = np.array([[2, 3], [1, 0], [2, 2], [0, 1]])

B = tf.placeholder(dtype=tf.int32, shape=[4, 2])

此时由于我们想要从二维矩阵中索引出 4 个数,所以 shape=[4, 2]

再次运行得到:

[12 5 11 2]

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

推广到 n 维矩阵中:

假设 A 是 Variable 类型的 n 维矩阵,我们想取出矩阵中的 m 个元素,那么首先每个元素的索引坐标要表示成列表的形式:

index = [x1, x2, x3, ..., xn]

其中 xj 代表该元素在 n 维矩阵中第 j 维的位置。

其次每个坐标要单独占索引矩阵的一行:

index_matrix = [[x11, x12, x13, ..., x1n],

               [x21, x22, x23, ..., x2n],

               [x31, x32, x33, ..., x3n],

               .......................................,

               [xm1, xm2, xm3, ..., xmn]]

最后用 tf.gather_nd() 函数替换即可:

result = tf.gather_nd(A, index_matrix)

////////////////////////////////////////////////////////////////////////////////////华丽丽的分割线

[注] 问题出自:https://stackoverflow.com/questions/44793286/slicing-tensorflow-tensor-with-tensor

以上这篇将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
python3读取csv和xlsx文件的实例
Jun 22 Python
Flask框架中request、请求钩子、上下文用法分析
Jul 23 Python
python实现图片二值化及灰度处理方式
Dec 07 Python
Python实现i人事自动打卡的示例代码
Jan 09 Python
pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换
Jan 13 Python
Python图像处理库PIL的ImageEnhance模块使用介绍
Feb 26 Python
Python发送邮件封装实现过程详解
May 09 Python
django 利用Q对象与F对象进行查询的实现
May 15 Python
python能做哪方面的工作
Jun 15 Python
浅谈PyTorch中in-place operation的含义
Jun 27 Python
python 字符串格式化的示例
Sep 21 Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
Jan 04 #Python
对tensorflow中的strides参数使用详解
Jan 04 #Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 #Python
TensorFlow tf.nn.max_pool实现池化操作方式
Jan 04 #Python
TensorFlow tf.nn.conv2d实现卷积的方式
Jan 03 #Python
Python调用钉钉自定义机器人的实现
Jan 03 #Python
pytorch中的上采样以及各种反操作,求逆操作详解
Jan 03 #Python
You might like
php UBB 解析实现代码
2011/11/27 PHP
Web程序工作原理详解
2014/12/25 PHP
smarty模板引擎使用内建函数foreach循环取出所有数组值的方法
2015/01/22 PHP
php使用QueryList轻松采集js动态渲染页面方法
2018/09/11 PHP
PHP7导出Excel报ERR_EMPTY_RESPONSE解决方法
2019/04/16 PHP
PHP rmdir()函数的用法总结
2019/07/02 PHP
laravel-admin 实现给grid的列添加行数序号的方法
2019/10/08 PHP
php中Swoole的热更新实现代码实例
2021/03/04 PHP
JavaScript 入门基础知识 想学习js的朋友可以参考下
2009/12/26 Javascript
对 lightbox JS 图片控件进行了一下改造, 使其他支持复杂的图片说明
2010/03/20 Javascript
js相册效果代码(点击创建即可)
2013/04/16 Javascript
JS文本框默认值处理详解
2013/07/10 Javascript
jquery使用jxl插件导出excel示例
2014/04/14 Javascript
异步安全加载javascript文件的方法
2015/07/21 Javascript
Bootstrap开发实战之第一次接触Bootstrap
2016/06/02 Javascript
全面解析Bootstrap中Carousel轮播的使用方法
2016/06/13 Javascript
JS原型链怎么理解
2016/06/27 Javascript
Highcharts学习之坐标轴
2016/08/02 Javascript
Web打印解决方案之证件套打的实现思路
2016/08/29 Javascript
AngularJS中的缓存使用
2017/01/11 Javascript
JavaScript数据结构之二叉树的查找算法示例
2017/04/13 Javascript
jquery中有哪些api jQuery主要API
2017/11/20 jQuery
详解使用Next.js构建服务端渲染应用
2018/07/10 Javascript
Vue实现远程获取路由与页面刷新导致404错误的解决
2019/01/31 Javascript
基于layui table返回的值的多级嵌套的解决方法
2019/09/19 Javascript
JS访问对象两种方式区别解析
2020/08/29 Javascript
python实现音乐下载器
2018/04/15 Python
Pycharm如何打断点的方法步骤
2019/06/13 Python
Django --Xadmin 判断登录者身份实例
2020/07/03 Python
养生餐厅创业计划书范文
2014/03/26 职场文书
2014年学校总务处工作总结
2014/12/08 职场文书
交通安全月活动总结
2015/05/08 职场文书
党员干部学习三严三实心得体会
2016/01/05 职场文书
小学毕业教师寄语
2019/06/21 职场文书
如何在CSS中绘制曲线图形及展示动画
2021/05/24 HTML / CSS
SpringBoot项目中控制台日志的保存配置操作
2021/06/18 Java/Android