探秘TensorFlow 和 NumPy 的 Broadcasting 机制


Posted in Python onMarch 13, 2020

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadcast即广播机制帮了大忙。而Tensorflow中的广播机制其实是效仿的numpy中的广播机制。本篇,我们就来一同研究下numpy和Tensorflow中的广播机制。

1、numpy广播原理

1.1 数组和标量计算时的广播

标量和数组合并时就会发生简单的广播,标量会和数组中的每一个元素进行计算。

举个例子:

arr = np.arange(5)
arr * 4

得到的输出为:

array([ 0,  4,  8, 12, 16])

这个是很好理解的,我们重点来研究数组之间的广播

1.2 数组之间计算时的广播

用书中的话来介绍广播的规则:两个数组之间广播的规则:如果两个数组的后缘维度(即从末尾开始算起的维度)的轴长度相等或其中一方的长度为1,则认为他们是广播兼容的,广播会在缺失和(或)长度为1的维度上进行。

上面的规则挺拗口的,我们举几个例子吧:

二维的情况

假设有一个二维数组,我们想要减去它在0轴和1轴的均值,这时的广播是什么样的呢。

我们先来看减去0轴均值的情况:

arr = np.arange(12).reshape(4,3)
arr-arr.mean(0)

输出的结果为:

array([[-4.5, -4.5, -4.5],
       [-1.5, -1.5, -1.5],
       [ 1.5,  1.5,  1.5],
       [ 4.5,  4.5,  4.5]])

0轴的平均值为[4.5,5.5,6.5],形状为(3,),而原数组形状为(4,3),在进行广播时,从后往前比较两个数组的形状,首先是3=3,满足条件而继续比较,这时候发现其中一个数组的形状数组遍历完成,因此会在缺失轴即0轴上进行广播。

可以理解成将均值数组在0轴上复制4份,变成形状(4,3)的数组,再与原数组进行计算。

书中的图形象的表示了这个过程(数据不一样请忽略):

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来看一下减去1轴平均值的情况,即每行都减去该行的平均值:

arr - arr.mean(1)

此时报错了:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来念叨一遍我们的广播规则,均值数组的形状为(4,),而原数组形状为(4,3),按照比较规则,4 != 3,因此不符合广播的条件,因此报错。

正确的做法是什么呢,因为原数组在0轴上的形状为4,我们的均值数组必须要先有一个值能够跟3比较同时满足我们的广播规则,这个值不用多想,就是1。因此我们需要先将均值数组变成(4,1)的形状,再去进行运算:

arr-arr.mean(1).reshape((4,1))

得到正确的结果:

array([[-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.]])

三维的情况

理解了二维的情况,我们也就能很快的理解三维数组的情况。

首先看下图:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

根据广播原则分析:arr1的shape为(3,4,2),arr2的shape为(4,2),它们的后缘轴长度都为(4,2),所以可以在0轴进行广播。因此,arr2在0轴上复制三份,shape变为(3,4,2),再进行计算。

不只是0轴,1轴和2轴也都可以进行广播。但形状必须满足一定的条件。举个例子来说,我们arr1的shape为(8,5,3),想要在0轴上广播的话,arr2的shape是(1,5,3)或者(5,3),想要在1轴上进行广播的话,arr2的shape是(8,1,3),想要在2轴上广播的话,arr2的shape必须是(8,5,1)。

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们来写几个例子吧:

arr2 = np.arange(24).reshape((2,3,4))
arr3_0 = np.arange(12).reshape((3,4))
print("0轴广播")
print(arr2 - arr3_0)

arr3_1 = np.arange(8).reshape((2,1,4))
print("1轴广播")
print(arr2 - arr3_1)

arr3_2 = np.arange(6).reshape((2,3,1))
print("2轴广播")
print(arr2 - arr3_2)

输出为:

0轴广播
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]

 [[12 12 12 12]
  [12 12 12 12]
  [12 12 12 12]]]
1轴广播
[[[ 0  0  0  0]
  [ 4  4  4  4]
  [ 8  8  8  8]]

 [[ 8  8  8  8]
  [12 12 12 12]
  [16 16 16 16]]]
2轴广播
[[[ 0  1  2  3]
  [ 3  4  5  6]
  [ 6  7  8  9]]

 [[ 9 10 11 12]
  [12 13 14 15]
  [15 16
 17 18]]]

如果我们想在两个轴上进行广播,那arr2的shape要满足什么条件呢?

arr1.shape 广播轴 arr2.shape
(8,5,3) 0,1 (3,),(1,3),(1,1,3)
(8,5,3) 0,2 (5,1),(1,5,1)
(8,5,3) 1,2 (8,1,1)

具体的例子就不给出啦,嘻嘻。

2、Tensorflow 广播举例

Tensorflow中的广播机制和numpy是一样的,因此我们给出一些简单的举例:

二维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3),0,0.1))
b = tf.Variable(tf.random_normal((2,1),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[-0.1419442 ,  0.14135399,  0.22752595],
       [ 0.1382471 ,  0.28228047,  0.13102233]], dtype=float32)

三维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,1,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
        [-0.12693939, -0.08069084, -0.15459496,  0.09405404],
        [ 0.09730847,  0.06936138,  0.04050628,  0.15374713]],

       [[-0.02691782, -0.26384184,  0.05825682, -0.07617196],
        [-0.02653179, -0.01997554, -0.06522765,  0.03028341],
        [-0.07577246,  0.03199019,  0.0321    , -0.12571403]]], dtype=float32)

错误示例

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].

到此这篇关于探秘TensorFlow 和 NumPy 的 Broadcasting 机制的文章就介绍到这了,更多相关TensorFlow 和NumPy 的Broadcasting 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python 判断自定义对象类型
Mar 21 Python
使用rst2pdf实现将sphinx生成PDF
Jun 07 Python
Python实现批量检测HTTP服务的状态
Oct 27 Python
对numpy和pandas中数组的合并和拆分详解
Apr 11 Python
Python使用matplotlib和pandas实现的画图操作【经典示例】
Jun 13 Python
python监测当前联网状态并连接的实例
Dec 18 Python
详解如何在Apache中运行Python WSGI应用
Jan 02 Python
Python多线程处理实例详解【单进程/多进程】
Jan 30 Python
Python 写入训练日志文件并控制台输出解析
Aug 13 Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 Python
pandas 操作 Excel操作总结
Mar 31 Python
python数字图像处理:图像的绘制
Jun 28 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 #Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 #Python
python实现俄罗斯方块游戏(改进版)
Mar 13 #Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 #Python
Tensorflow中的dropout的使用方法
Mar 13 #Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
You might like
PHP网页游戏学习之Xnova(ogame)源码解读(一)
2014/06/23 PHP
Thinkphp搜索时首页分页和搜索页保持条件分页的方法
2014/12/05 PHP
php的无刷新操作实现方法分析
2020/02/28 PHP
javascript 处理HTML元素必须避免使用的一种方法
2009/07/30 Javascript
javascript中"/"运算符常见错误
2010/10/13 Javascript
Javascript 垃圾收集机制介绍理解
2013/05/14 Javascript
javascript抖动元素的小例子
2013/10/28 Javascript
根据身份证号自动输出相关信息(籍贯,出身日期,性别)
2013/11/15 Javascript
详细分析JavaScript函数定义
2015/07/16 Javascript
学习javascript的闭包,原型,和匿名函数之旅
2015/10/18 Javascript
nodeJs爬虫获取数据简单实现代码
2016/03/29 NodeJs
vue2.0实现导航菜单切换效果
2017/05/08 Javascript
axios 封装上传文件的请求方法
2018/09/26 Javascript
使用vue-router切换页面时实现设置过渡动画
2019/10/31 Javascript
JavaScript实现轮播图效果
2020/10/30 Javascript
[02:25]DOTA2英雄基础教程 虚空假面
2014/01/02 DOTA
python生成器generator用法实例分析
2015/06/04 Python
Python中使用bidict模块双向字典结构的奇技淫巧
2016/07/12 Python
python 安装virtualenv和virtualenvwrapper的方法
2017/01/13 Python
详解用Python处理HTML转义字符的5种方式
2017/12/27 Python
Python实现判断给定列表是否有重复元素的方法
2018/04/11 Python
浅析python的优势和不足之处
2018/11/20 Python
python采集微信公众号文章
2018/12/20 Python
用Python将Excel数据导入到SQL Server的例子
2019/08/24 Python
sklearn-SVC实现与类参数详解
2019/12/10 Python
如何用python写个模板引擎
2021/01/14 Python
12个不为大家熟知的HTML5设计小技巧
2016/06/02 HTML / CSS
Yves Rocher捷克官方网站:植物化妆品的创造者
2019/07/31 全球购物
青年志愿者事迹材料
2014/02/07 职场文书
2015年行风建设工作总结
2015/05/15 职场文书
三年级作文之小小梦想
2019/12/06 职场文书
2020年元旦晚会策划书模板
2019/12/30 职场文书
拒绝盗图!教你怎么用python给图片加水印
2021/06/04 Python
sass 常用备忘案例详解
2021/09/15 HTML / CSS
Windows server 2012搭建FTP服务器
2022/04/29 Servers
mysql 体系结构和存储引擎介绍
2022/05/06 MySQL