详解torch.Tensor的4种乘法


Posted in Python onSeptember 03, 2020

torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul. 本文抛砖引玉,简单叙述一下这4种乘法的区别,具体使用还是要参照官方文档。

点乘

a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法

下面以*标量和*一维向量为例展示上述过程。

* 标量

Tensor与标量k做*乘法的结果是Tensor的每个元素乘以k(相当于把k复制成与lhs大小相同,元素全为k的Tensor).

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
    [2., 2., 2., 2.],
    [2., 2., 2., 2.]])

* 一维向量

Tensor与行向量做*乘法的结果是每列乘以行向量对应列的值(相当于把行向量的行复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的列数与行向量的列数相等。

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> a * b
tensor([[1., 2., 3., 4.],
    [1., 2., 3., 4.],
    [1., 2., 3., 4.]])

Tensor与列向量做*乘法的结果是每行乘以列向量对应行的值(相当于把列向量的列复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的行数与列向量的行数相等。

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
    [2.],
    [3.]])
>>> a * b
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])

* 矩阵

经Arsmart在评论区提醒,增补一个矩阵 * 矩阵的例子,感谢Arsmart的热心评论!
如果两个二维矩阵A与B做点积A * B,则要求A与B的维度完全相同,即A的行数=B的行数,A的列数=B的列数

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> a * a
tensor([[1, 4],
    [4, 9]])

broadcast

点积是broadcast的。broadcast是torch的一个概念,简单理解就是在一定的规则下允许高维Tensor和低维Tensor之间的运算。broadcast的概念稍显复杂,在此不做展开,可以参考官方文档关于broadcast的介绍. 在torch.matmul里会有关于broadcast的应用的一个简单的例子。

这里举一个点积broadcast的例子。在例子中,a是二维Tensor,b是三维Tensor,但是a的维度与b的后两位相同,那么a和b仍然可以做点积,点积结果是一个和b维度一样的三维Tensor,运算规则是:若c = a * b, 则c[i,*,*] = a * b[i, *, *],即沿着b的第0维做二维Tensor点积,或者可以理解为运算前将a沿着b的第0维也进行了expand操作,即a = a.expand(b.size()); a * b

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])
>>> a * b
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])
>>> b * a
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])

其实,上面提到的二维Tensor点积标量、二维Tensor点积行向量,都是发生在高维向量和低维向量之间的,也可以看作是broadcast.

torch.mul

官方文档关于torch.mul的介绍. 用法与*乘法相同,也是element-wise的乘法,也是支持broadcast的。

下面是几个torch.mul的例子.

乘标量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
    [2., 2., 2., 2.],
    [2., 2., 2., 2.]])

乘行向量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> torch.mul(a, b)
tensor([[1., 2., 3., 4.],
    [1., 2., 3., 4.],
    [1., 2., 3., 4.]])

乘列向量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
    [2.],
    [3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])

乘矩阵

例1:二维矩阵 mul 二维矩阵

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> torch.mul(a,a)
tensor([[1, 4],
    [4, 9]])

例2:二维矩阵 mul 三维矩阵(broadcast)

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])
>>> torch.mul(a,b)
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])

torch.mm

官方文档关于torch.mm的介绍. 数学里的矩阵乘法,要求两个Tensor的维度满足矩阵乘法的要求.

例子:

>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
    [4., 4.],
    [4., 4.]])

torch.matmul

官方文档关于torch.matmul的介绍. torch.mm的broadcast版本.

例子:

>>> a = torch.ones(3,4)
>>> b = torch.ones(5,4,2)
>>> torch.matmul(a, b)
tensor([[[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]]])

同样的a和b,使用torch.mm相乘会报错

>>> torch.mm(a, b)
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 2D, 3D tensors at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2065

到此这篇关于详解torch.Tensor的4种乘法的文章就介绍到这了,更多相关torch.Tensor 乘法内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python排序搜索基本算法之归并排序实例分析
Dec 08 Python
python中yaml配置文件模块的使用详解
Apr 27 Python
Python实现多条件筛选目标数据功能【测试可用】
Jun 13 Python
对python过滤器和lambda函数的用法详解
Jan 21 Python
用scikit-learn和pandas学习线性回归的方法
Jun 21 Python
Django时区详解
Jul 24 Python
Python tkinter和exe打包的方法
Feb 05 Python
python 函数嵌套及多函数共同运行知识点讲解
Mar 03 Python
pycharm使用技巧之自动调整代码格式总结
Nov 04 Python
Python基于opencv的简单图像轮廓形状识别(全网最简单最少代码)
Jan 28 Python
用Python远程登陆服务器的步骤
Apr 16 Python
TensorFlow中tf.batch_matmul()的用法
Jun 02 Python
详解pytorch tensor和ndarray转换相关总结
Sep 03 #Python
python开发入门——列表生成式
Sep 03 #Python
Pytorch之Tensor和Numpy之间的转换的实现方法
Sep 03 #Python
Python 多线程C段扫描、检测 Ping扫描脚本的实现
Sep 03 #Python
Python开发入门——迭代的基本使用
Sep 03 #Python
Python 整行读取文本方法并去掉readlines换行\n操作
Sep 03 #Python
Python多分支if语句的使用
Sep 03 #Python
You might like
在PHP中使用Sockets 从Usenet中获取文件
2008/01/10 PHP
PHP 函数语法介绍一
2009/06/14 PHP
ThinkPHP快速入门实例教程之数据分页
2014/07/01 PHP
jQuery+php简单实现全选删除的方法
2016/11/28 PHP
php获得刚插入数据的id 的几种方法总结
2018/05/31 PHP
详解Laravel5.6 Passport实现Api接口认证
2018/07/27 PHP
TP5多入口设置实例讲解
2020/12/15 PHP
静态页面下用javascript操作ACCESS数据库(读增改删)的代码
2007/05/14 Javascript
JQuery中$之选择器用法介绍
2011/04/05 Javascript
分享14个很酷的jQuery导航菜单插件
2011/04/25 Javascript
一个简单的网站访问JS计数器 刷新1次加1次访问
2012/09/20 Javascript
JS保存和删除cookie操作 判断cookie是否存在
2013/11/13 Javascript
jquery可定制的在线UEditor编辑器
2015/11/17 Javascript
详解javascript实现自定义事件
2016/01/19 Javascript
理解javascript定时器中的setTimeout与setInterval
2016/02/23 Javascript
如何利用Promises编写更优雅的JavaScript代码
2016/05/17 Javascript
Vue+webpack项目配置便于维护的目录结构教程详解
2018/10/14 Javascript
详解如何解决vue开发请求数据跨域的问题(基于浏览器的配置解决)
2018/11/12 Javascript
nodejs简单抓包工具使用详解
2019/08/23 NodeJs
通过C++学习Python
2015/01/20 Python
深入理解Javascript中的this关键字
2015/03/27 Python
在Python的Django框架中实现Hacker News的一些功能
2015/04/17 Python
Windows下Eclipse+PyDev配置Python+PyQt4开发环境
2016/05/17 Python
Python安装lz4-0.10.1遇到的坑
2018/05/20 Python
python实现根据给定坐标点生成多边形mask的例子
2020/02/18 Python
python3:excel操作之读取数据并返回字典 + 写入的案例
2020/09/01 Python
python rsa-oaep加密的示例代码
2020/09/23 Python
python代数式括号有效性检验示例代码
2020/10/04 Python
澳大利亚领先的亚麻品牌:Bed Threads
2019/12/16 全球购物
广州御银科技股份有限公司试卷(C++)
2016/11/04 面试题
申论倡议书范文
2014/05/13 职场文书
会计师事务所实习证明
2014/11/16 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
2016寒假假期总结
2015/10/10 职场文书
浅谈MySQL之浅入深出页原理
2021/06/23 MySQL
mysql使用 not int 子查询隐含陷阱
2022/04/12 MySQL