探秘TensorFlow 和 NumPy 的 Broadcasting 机制


Posted in Python onMarch 13, 2020

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadcast即广播机制帮了大忙。而Tensorflow中的广播机制其实是效仿的numpy中的广播机制。本篇,我们就来一同研究下numpy和Tensorflow中的广播机制。

1、numpy广播原理

1.1 数组和标量计算时的广播

标量和数组合并时就会发生简单的广播,标量会和数组中的每一个元素进行计算。

举个例子:

arr = np.arange(5)
arr * 4

得到的输出为:

array([ 0,  4,  8, 12, 16])

这个是很好理解的,我们重点来研究数组之间的广播

1.2 数组之间计算时的广播

用书中的话来介绍广播的规则:两个数组之间广播的规则:如果两个数组的后缘维度(即从末尾开始算起的维度)的轴长度相等或其中一方的长度为1,则认为他们是广播兼容的,广播会在缺失和(或)长度为1的维度上进行。

上面的规则挺拗口的,我们举几个例子吧:

二维的情况

假设有一个二维数组,我们想要减去它在0轴和1轴的均值,这时的广播是什么样的呢。

我们先来看减去0轴均值的情况:

arr = np.arange(12).reshape(4,3)
arr-arr.mean(0)

输出的结果为:

array([[-4.5, -4.5, -4.5],
       [-1.5, -1.5, -1.5],
       [ 1.5,  1.5,  1.5],
       [ 4.5,  4.5,  4.5]])

0轴的平均值为[4.5,5.5,6.5],形状为(3,),而原数组形状为(4,3),在进行广播时,从后往前比较两个数组的形状,首先是3=3,满足条件而继续比较,这时候发现其中一个数组的形状数组遍历完成,因此会在缺失轴即0轴上进行广播。

可以理解成将均值数组在0轴上复制4份,变成形状(4,3)的数组,再与原数组进行计算。

书中的图形象的表示了这个过程(数据不一样请忽略):

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来看一下减去1轴平均值的情况,即每行都减去该行的平均值:

arr - arr.mean(1)

此时报错了:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来念叨一遍我们的广播规则,均值数组的形状为(4,),而原数组形状为(4,3),按照比较规则,4 != 3,因此不符合广播的条件,因此报错。

正确的做法是什么呢,因为原数组在0轴上的形状为4,我们的均值数组必须要先有一个值能够跟3比较同时满足我们的广播规则,这个值不用多想,就是1。因此我们需要先将均值数组变成(4,1)的形状,再去进行运算:

arr-arr.mean(1).reshape((4,1))

得到正确的结果:

array([[-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.]])

三维的情况

理解了二维的情况,我们也就能很快的理解三维数组的情况。

首先看下图:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

根据广播原则分析:arr1的shape为(3,4,2),arr2的shape为(4,2),它们的后缘轴长度都为(4,2),所以可以在0轴进行广播。因此,arr2在0轴上复制三份,shape变为(3,4,2),再进行计算。

不只是0轴,1轴和2轴也都可以进行广播。但形状必须满足一定的条件。举个例子来说,我们arr1的shape为(8,5,3),想要在0轴上广播的话,arr2的shape是(1,5,3)或者(5,3),想要在1轴上进行广播的话,arr2的shape是(8,1,3),想要在2轴上广播的话,arr2的shape必须是(8,5,1)。

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们来写几个例子吧:

arr2 = np.arange(24).reshape((2,3,4))
arr3_0 = np.arange(12).reshape((3,4))
print("0轴广播")
print(arr2 - arr3_0)

arr3_1 = np.arange(8).reshape((2,1,4))
print("1轴广播")
print(arr2 - arr3_1)

arr3_2 = np.arange(6).reshape((2,3,1))
print("2轴广播")
print(arr2 - arr3_2)

输出为:

0轴广播
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]

 [[12 12 12 12]
  [12 12 12 12]
  [12 12 12 12]]]
1轴广播
[[[ 0  0  0  0]
  [ 4  4  4  4]
  [ 8  8  8  8]]

 [[ 8  8  8  8]
  [12 12 12 12]
  [16 16 16 16]]]
2轴广播
[[[ 0  1  2  3]
  [ 3  4  5  6]
  [ 6  7  8  9]]

 [[ 9 10 11 12]
  [12 13 14 15]
  [15 16
 17 18]]]

如果我们想在两个轴上进行广播,那arr2的shape要满足什么条件呢?

arr1.shape 广播轴 arr2.shape
(8,5,3) 0,1 (3,),(1,3),(1,1,3)
(8,5,3) 0,2 (5,1),(1,5,1)
(8,5,3) 1,2 (8,1,1)

具体的例子就不给出啦,嘻嘻。

2、Tensorflow 广播举例

Tensorflow中的广播机制和numpy是一样的,因此我们给出一些简单的举例:

二维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3),0,0.1))
b = tf.Variable(tf.random_normal((2,1),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[-0.1419442 ,  0.14135399,  0.22752595],
       [ 0.1382471 ,  0.28228047,  0.13102233]], dtype=float32)

三维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,1,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
        [-0.12693939, -0.08069084, -0.15459496,  0.09405404],
        [ 0.09730847,  0.06936138,  0.04050628,  0.15374713]],

       [[-0.02691782, -0.26384184,  0.05825682, -0.07617196],
        [-0.02653179, -0.01997554, -0.06522765,  0.03028341],
        [-0.07577246,  0.03199019,  0.0321    , -0.12571403]]], dtype=float32)

错误示例

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].

到此这篇关于探秘TensorFlow 和 NumPy 的 Broadcasting 机制的文章就介绍到这了,更多相关TensorFlow 和NumPy 的Broadcasting 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现压缩和解压缩ZIP文件的方法分析
Sep 28 Python
将Django项目部署到CentOs服务器中
Oct 18 Python
python flask web服务实现更换默认端口和IP的方法
Jul 26 Python
Python中的 sort 和 sorted的用法与区别
Aug 10 Python
Python爬取知乎图片代码实现解析
Sep 17 Python
numpy 声明空数组详解
Dec 05 Python
Python魔法方法 容器部方法详解
Jan 02 Python
Django中的模型类设计及展示示例详解
May 29 Python
python 8种必备的gui库
Aug 27 Python
Python SMTP发送电子邮件的示例
Sep 23 Python
Python hashlib模块的使用示例
Oct 09 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 #Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 #Python
python实现俄罗斯方块游戏(改进版)
Mar 13 #Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 #Python
Tensorflow中的dropout的使用方法
Mar 13 #Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
You might like
咖啡是不是喝了会上瘾?咖啡是必须品吗!
2021/03/04 新手入门
使用TinyButStrong模板引擎来做WEB开发
2007/03/16 PHP
PHP 开发环境配置(Zend Studio)
2010/04/28 PHP
PHP 类商品秒杀计时实现代码
2010/05/05 PHP
深入理解用mysql_fetch_row()以数组的形式返回查询结果
2013/06/05 PHP
PHP实现设计模式中的抽象工厂模式详解
2014/10/11 PHP
php中smarty区域循环的方法
2015/06/11 PHP
2017年最新PHP经典面试题目汇总(上篇)
2017/03/17 PHP
豆瓣网的jquery代码实例
2008/06/15 Javascript
JavaScript中null与undefined分析
2009/07/25 Javascript
javascript之typeof、instanceof操作符使用探讨
2013/05/19 Javascript
20行代码实现的一个CSS覆盖率测试脚本
2013/07/07 Javascript
Get中文乱码IE浏览器Get中文乱码解决方案
2013/12/26 Javascript
JavaScript操作HTML元素和样式的方法详解
2015/10/21 Javascript
基于jquery实现表格内容筛选功能实例解析
2016/05/09 Javascript
关于Jquery中的bind(),on()绑定事件方式总结
2016/10/26 Javascript
javascript 实现文本使用省略号替代(超出固定高度的情况)
2017/02/21 Javascript
BootStrap表单验证 FormValidation 调整反馈图标位置的实例代码
2017/05/17 Javascript
jQuery之动画ajax事件(实例讲解)
2017/07/18 jQuery
JS实现字符串中去除指定子字符串方法分析
2018/05/17 Javascript
解决axios发送post请求返回400状态码的问题
2018/08/11 Javascript
[03:40]DOTA2亚洲邀请赛小组赛第二日 赛事回顾
2015/01/31 DOTA
python中利用Future对象回调别的函数示例代码
2017/09/07 Python
python表格存取的方法
2018/03/07 Python
python 读取txt中每行数据,并且保存到excel中的实例
2018/04/29 Python
Python实现的列表排序、反转操作示例
2019/03/13 Python
使用CSS3的::selection改变选中文本颜色的方法
2015/09/29 HTML / CSS
购买一个高级域名:BuyDomains
2018/03/11 全球购物
公共汽车、火车和飞机票的通用在线预订和销售平台:INFOBUS
2019/11/30 全球购物
英国哈罗德园艺:Harrod Horticultural
2020/03/31 全球购物
内部类的定义、种类以及优点
2013/10/16 面试题
总经理司机岗位职责
2014/02/06 职场文书
趣味游戏活动方案
2014/02/07 职场文书
检举信的格式及范文
2014/04/04 职场文书
摩登时代观后感
2015/06/03 职场文书
pytorch 如何把图像数据集进行划分成train,test和val
2021/05/31 Python