TensorFlow中tf.batch_matmul()的用法


Posted in Python onJune 02, 2021

TensorFlow中tf.batch_matmul()用法

如果有两个三阶张量,size分别为

a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)

则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变

100为张量的batch_size。剩下的两个维度为数据的维度。

不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。

附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。

TensorFlow如何实现batch_matmul

我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。

但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。

例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。

方法1: 利用tf.matmul()

对tensor B 进行增维和扩展

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)

方法2: 利用tf.reshape()

对tensor A 进行reshape操作,然后利用tf.matmul()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])

方法3: 利用tf.scan()

利用tf.scan() 对tensor按第0维进行展开的特性

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)

方法4: 利用tf.einsum()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用rabbitmq实现网络爬虫示例
Feb 20 Python
python网络编程之TCP通信实例和socketserver框架使用例子
Apr 25 Python
Python素数检测实例分析
Jun 15 Python
Python解惑之整数比较详解
Apr 24 Python
基于python(urlparse)模板的使用方法总结
Oct 13 Python
[原创]教女朋友学Python(一)运行环境搭建
Nov 29 Python
python奇偶行分开存储实现代码
Mar 19 Python
python处理DICOM并计算三维模型体积
Feb 26 Python
Pyqt清空某一个QTreeewidgetItem下的所有分支方法
Jun 17 Python
python GUI库图形界面开发之PyQt5信号与槽事件处理机制详细介绍与实例解析
Mar 08 Python
Python机器学习之KNN近邻算法
May 14 Python
Python自动化测试PO模型封装过程详解
Jun 22 Python
pytorch 运行一段时间后出现GPU OOM的问题
Jun 02 #Python
python flask开发的简单基金查询工具
python爬取网页版QQ空间,生成各类图表
Python爬虫实战之爬取携程评论
Pytorch DataLoader shuffle验证方式
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
You might like
让PHP支持断点续传的源码
2010/05/16 PHP
php数据库密码的找回的步骤
2011/01/12 PHP
初品cakephp 入门基础
2012/02/16 PHP
PDO版本问题 Invalid parameter number: no parameters were bound
2013/01/06 PHP
phpQuery占用内存过多的处理方法
2013/11/13 PHP
php实现给图片加灰色半透明效果的方法
2014/10/20 PHP
基于laravel where的高级使用方法
2019/10/10 PHP
浅析Js中的单引号与双引号问题
2013/11/06 Javascript
jquery淡化版banner异步图片文字效果切换图片特效
2014/04/08 Javascript
Jquery仿IGoogle实现可拖动窗口示例代码
2014/08/22 Javascript
Javascript冒泡排序算法详解
2014/12/03 Javascript
浅析javascript操作 cookie对象
2014/12/26 Javascript
JavaScript多线程详解
2015/08/12 Javascript
BootStrap下拉框在firefox浏览器界面不友好的解决方案
2016/08/18 Javascript
canvas仿iwatch时钟效果
2017/03/06 Javascript
Node.js操作redis实现添加查询功能
2017/05/25 Javascript
JS+WCF实现进度条实时监测数据加载量的方法详解
2017/12/19 Javascript
使用vue官方提供的模板vue-cli搭建一个helloWorld案例分析
2018/01/16 Javascript
浅谈VUE单页应用首屏加载速度优化方案
2018/08/28 Javascript
[01:02:55]CHAOS vs Mineski 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
python使用三角迭代计算圆周率PI的方法
2015/03/20 Python
python选择排序算法实例总结
2015/07/01 Python
Python脚本处理空格的方法
2016/08/08 Python
Python二叉树的遍历操作示例【前序遍历,中序遍历,后序遍历,层序遍历】
2018/12/24 Python
Python3.7 新特性之dataclass装饰器
2019/05/27 Python
基于Python获取城市近7天天气预报
2019/11/26 Python
Python 用__new__方法实现单例的操作
2020/12/11 Python
如何用python爬取微博热搜数据并保存
2021/02/20 Python
CSS的pointer-events属性详细介绍(作用和注意事项)
2014/04/23 HTML / CSS
英国婚礼商城:Wedding Mall
2019/11/02 全球购物
上班早退检讨书
2014/01/09 职场文书
信息科学与技术专业求职信范文
2014/02/20 职场文书
年会搞笑主持词
2014/03/27 职场文书
质量保证书范本
2014/04/29 职场文书
中学生检讨书范文
2014/11/03 职场文书
vue如何清除浏览器历史栈
2022/05/25 Vue.js