TensorFlow中tf.batch_matmul()的用法


Posted in Python onJune 02, 2021

TensorFlow中tf.batch_matmul()用法

如果有两个三阶张量,size分别为

a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)

则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变

100为张量的batch_size。剩下的两个维度为数据的维度。

不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。

附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。

TensorFlow如何实现batch_matmul

我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。

但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。

例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。

方法1: 利用tf.matmul()

对tensor B 进行增维和扩展

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)

方法2: 利用tf.reshape()

对tensor A 进行reshape操作,然后利用tf.matmul()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])

方法3: 利用tf.scan()

利用tf.scan() 对tensor按第0维进行展开的特性

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)

方法4: 利用tf.einsum()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类继承用法实例分析
May 27 Python
快速了解Python相对导入
Jan 12 Python
Python网络爬虫神器PyQuery的基本使用教程
Feb 03 Python
Flask核心机制之上下文源码剖析
Dec 25 Python
python 实现12bit灰度图像映射到8bit显示的方法
Jul 08 Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 Python
Python3 全自动更新已安装的模块实现
Jan 06 Python
python实现学生成绩测评系统
Jun 22 Python
公认8个效率最高的爬虫框架
Jul 28 Python
详解python with 上下文管理器
Sep 02 Python
python 读取串口数据的示例
Nov 09 Python
Pandas的数据过滤实现
Jan 15 Python
pytorch 运行一段时间后出现GPU OOM的问题
Jun 02 #Python
python flask开发的简单基金查询工具
python爬取网页版QQ空间,生成各类图表
Python爬虫实战之爬取携程评论
Pytorch DataLoader shuffle验证方式
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
You might like
一些操作和快捷键的理解和讨论
2020/03/04 星际争霸
5种PHP创建数组的实例代码分享
2014/01/17 PHP
解读PHP中的垃圾回收机制
2015/08/10 PHP
详解PHP5.6.30与Apache2.4.x配置
2017/06/02 PHP
kindeditor 加入七牛云上传的实例讲解
2017/11/12 PHP
PHP实现的分解质因数操作示例
2018/08/01 PHP
php7 图形用户界面GUI 开发示例
2020/02/22 PHP
关于火狐(firefox)及ie下event获取的两种方法
2012/12/27 Javascript
JavaScript调用堆栈及setTimeout使用方法深入剖析
2013/02/16 Javascript
商城常用滚动的焦点图效果代码简单实用
2013/03/28 Javascript
javascript中对变量类型的判断方法
2015/08/09 Javascript
Javascript编程中几种继承方式比较分析
2015/11/28 Javascript
node.js cookie-parser之parser.js
2016/06/06 Javascript
jQuery EasyUI API 中文帮助文档和扩展实例
2016/08/01 Javascript
使用Node.js搭建静态资源服务详细教程
2017/08/02 Javascript
解决html-jquery/js引用外部图片时遇到看不了或出现403的问题
2017/09/22 jQuery
使用webpack搭建react开发环境的方法
2018/05/15 Javascript
Vue keepAlive 数据缓存工具实现返回上一个页面浏览的位置
2019/05/10 Javascript
layer.prompt使文本框为空的情况下也能点击确定的方法
2019/09/24 Javascript
详解如何修改 node_modules 里的文件
2020/05/22 Javascript
[01:05:59]Mineski vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
在Django的模型中执行原始SQL查询的方法
2015/07/21 Python
Python实现压缩文件夹与解压缩zip文件的方法
2018/09/01 Python
解决python文件双击运行秒退的问题
2019/06/24 Python
Python抓新型冠状病毒肺炎疫情数据并绘制全国疫情分布的代码实例
2020/02/05 Python
Django中modelform组件实例用法总结
2020/02/10 Python
Python列表切片常用操作实例解析
2020/03/10 Python
如何查看Django ORM执行的SQL语句的实现
2020/04/20 Python
PyQt5 QDockWidget控件应用详解
2020/08/12 Python
Python环境使用OpenCV检测人脸实现教程
2020/10/19 Python
结合CSS3的布局新特征谈谈常见布局方法
2016/01/22 HTML / CSS
CK巴西官方网站:Calvin Klein巴西
2019/07/19 全球购物
初中班主任评语
2014/04/24 职场文书
舞蹈教育学专业求职信
2014/06/29 职场文书
亮剑精神观后感
2015/06/05 职场文书
《梅花魂》教学反思
2016/02/18 职场文书