TensorFlow中tf.batch_matmul()的用法


Posted in Python onJune 02, 2021

TensorFlow中tf.batch_matmul()用法

如果有两个三阶张量,size分别为

a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)

则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变

100为张量的batch_size。剩下的两个维度为数据的维度。

不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。

附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。

TensorFlow如何实现batch_matmul

我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。

但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。

例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。

方法1: 利用tf.matmul()

对tensor B 进行增维和扩展

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)

方法2: 利用tf.reshape()

对tensor A 进行reshape操作,然后利用tf.matmul()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])

方法3: 利用tf.scan()

利用tf.scan() 对tensor按第0维进行展开的特性

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)

方法4: 利用tf.einsum()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则表达式介绍
Aug 06 Python
详解python如何调用C/C++底层库与互相传值
Aug 10 Python
Python微信库:itchat的用法详解
Aug 14 Python
Python 列表理解及使用方法
Oct 27 Python
浅析python参数的知识点
Dec 10 Python
pandas读取csv文件,分隔符参数sep的实例
Dec 12 Python
解决Python3 抓取微信账单信息问题
Jul 19 Python
python3.7 的新特性详解
Jul 25 Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 Python
详解python 支持向量机(SVM)算法
Sep 18 Python
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 Python
Python多线程 Queue 模块常见用法
Jul 04 Python
pytorch 运行一段时间后出现GPU OOM的问题
Jun 02 #Python
python flask开发的简单基金查询工具
python爬取网页版QQ空间,生成各类图表
Python爬虫实战之爬取携程评论
Pytorch DataLoader shuffle验证方式
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
You might like
PHP生成各种常见验证码和Ajax验证过程
2016/01/10 PHP
PHP实现的折半查找算法示例
2017/12/19 PHP
jValidate 基于jQuery的表单验证插件
2009/12/12 Javascript
js 实现无干扰阴影效果 简单好用(附文件下载)
2009/12/27 Javascript
js用Date对象处理时间实现思路及代码
2013/01/31 Javascript
jquery ui对话框实例代码
2013/05/10 Javascript
input禁止键盘及中文输入,但可以点击
2014/02/13 Javascript
js实现背景图片感应鼠标变化的方法
2015/02/28 Javascript
JS实现向表格行添加新单元格的方法
2015/03/30 Javascript
JS 事件绑定、事件监听、事件委托详细介绍
2016/09/28 Javascript
Angularjs 创建可复用组件实例代码
2016/10/09 Javascript
浅谈JavaScript事件绑定的常用方法及其优缺点分析
2016/11/01 Javascript
jquery点击展示与隐藏更多内容
2016/12/03 Javascript
详解jQuery简单的表单应用
2016/12/16 Javascript
微信小程序实现图片自适应(支持多图)
2017/01/25 Javascript
javascript实现table单元格点击展开隐藏效果(实例代码)
2017/04/10 Javascript
ztree简介_动力节点Java学院整理
2017/07/19 Javascript
基于jQuery解决ios10以上版本缩放问题
2017/11/03 jQuery
Vue插件打包与发布的方法示例
2018/08/20 Javascript
详解vue更改头像功能实现
2019/04/28 Javascript
[02:41]DOTA2英雄基础教程 谜团
2013/12/10 DOTA
[03:58]兄弟们,回来开黑了!DOTA2昔日战友招募宣传视频
2016/07/17 DOTA
[01:00:25]NB vs Secret 2018国际邀请赛小组赛BO1 B组加赛 8.19
2018/08/21 DOTA
python2.7使用plotly绘制本地散点图和折线图
2019/04/02 Python
python实现桌面气泡提示功能
2019/07/29 Python
Python3之乱码\xe6\x97\xa0\xe6\xb3\x95处理方式
2020/05/11 Python
CSS3 transforms应用于背景图像的解决方法
2019/04/16 HTML / CSS
汽车销售求职自荐信
2013/10/01 职场文书
秘书岗位职责
2013/11/18 职场文书
大班幼儿评语大全
2014/04/30 职场文书
感恩母亲节演讲稿
2014/05/07 职场文书
简历自我评价模板
2015/03/11 职场文书
升职自我推荐信范文
2015/03/25 职场文书
甲午风云观后感
2015/06/02 职场文书
浅谈PostgreSQL表分区的三种方式
2021/06/29 PostgreSQL
Win11 S Mode版本泄露 正式上线后叫做Windows 11 SE
2021/11/21 数码科技