探秘TensorFlow 和 NumPy 的 Broadcasting 机制


Posted in Python onMarch 13, 2020

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadcast即广播机制帮了大忙。而Tensorflow中的广播机制其实是效仿的numpy中的广播机制。本篇,我们就来一同研究下numpy和Tensorflow中的广播机制。

1、numpy广播原理

1.1 数组和标量计算时的广播

标量和数组合并时就会发生简单的广播,标量会和数组中的每一个元素进行计算。

举个例子:

arr = np.arange(5)
arr * 4

得到的输出为:

array([ 0,  4,  8, 12, 16])

这个是很好理解的,我们重点来研究数组之间的广播

1.2 数组之间计算时的广播

用书中的话来介绍广播的规则:两个数组之间广播的规则:如果两个数组的后缘维度(即从末尾开始算起的维度)的轴长度相等或其中一方的长度为1,则认为他们是广播兼容的,广播会在缺失和(或)长度为1的维度上进行。

上面的规则挺拗口的,我们举几个例子吧:

二维的情况

假设有一个二维数组,我们想要减去它在0轴和1轴的均值,这时的广播是什么样的呢。

我们先来看减去0轴均值的情况:

arr = np.arange(12).reshape(4,3)
arr-arr.mean(0)

输出的结果为:

array([[-4.5, -4.5, -4.5],
       [-1.5, -1.5, -1.5],
       [ 1.5,  1.5,  1.5],
       [ 4.5,  4.5,  4.5]])

0轴的平均值为[4.5,5.5,6.5],形状为(3,),而原数组形状为(4,3),在进行广播时,从后往前比较两个数组的形状,首先是3=3,满足条件而继续比较,这时候发现其中一个数组的形状数组遍历完成,因此会在缺失轴即0轴上进行广播。

可以理解成将均值数组在0轴上复制4份,变成形状(4,3)的数组,再与原数组进行计算。

书中的图形象的表示了这个过程(数据不一样请忽略):

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来看一下减去1轴平均值的情况,即每行都减去该行的平均值:

arr - arr.mean(1)

此时报错了:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来念叨一遍我们的广播规则,均值数组的形状为(4,),而原数组形状为(4,3),按照比较规则,4 != 3,因此不符合广播的条件,因此报错。

正确的做法是什么呢,因为原数组在0轴上的形状为4,我们的均值数组必须要先有一个值能够跟3比较同时满足我们的广播规则,这个值不用多想,就是1。因此我们需要先将均值数组变成(4,1)的形状,再去进行运算:

arr-arr.mean(1).reshape((4,1))

得到正确的结果:

array([[-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.]])

三维的情况

理解了二维的情况,我们也就能很快的理解三维数组的情况。

首先看下图:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

根据广播原则分析:arr1的shape为(3,4,2),arr2的shape为(4,2),它们的后缘轴长度都为(4,2),所以可以在0轴进行广播。因此,arr2在0轴上复制三份,shape变为(3,4,2),再进行计算。

不只是0轴,1轴和2轴也都可以进行广播。但形状必须满足一定的条件。举个例子来说,我们arr1的shape为(8,5,3),想要在0轴上广播的话,arr2的shape是(1,5,3)或者(5,3),想要在1轴上进行广播的话,arr2的shape是(8,1,3),想要在2轴上广播的话,arr2的shape必须是(8,5,1)。

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们来写几个例子吧:

arr2 = np.arange(24).reshape((2,3,4))
arr3_0 = np.arange(12).reshape((3,4))
print("0轴广播")
print(arr2 - arr3_0)

arr3_1 = np.arange(8).reshape((2,1,4))
print("1轴广播")
print(arr2 - arr3_1)

arr3_2 = np.arange(6).reshape((2,3,1))
print("2轴广播")
print(arr2 - arr3_2)

输出为:

0轴广播
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]

 [[12 12 12 12]
  [12 12 12 12]
  [12 12 12 12]]]
1轴广播
[[[ 0  0  0  0]
  [ 4  4  4  4]
  [ 8  8  8  8]]

 [[ 8  8  8  8]
  [12 12 12 12]
  [16 16 16 16]]]
2轴广播
[[[ 0  1  2  3]
  [ 3  4  5  6]
  [ 6  7  8  9]]

 [[ 9 10 11 12]
  [12 13 14 15]
  [15 16
 17 18]]]

如果我们想在两个轴上进行广播,那arr2的shape要满足什么条件呢?

arr1.shape 广播轴 arr2.shape
(8,5,3) 0,1 (3,),(1,3),(1,1,3)
(8,5,3) 0,2 (5,1),(1,5,1)
(8,5,3) 1,2 (8,1,1)

具体的例子就不给出啦,嘻嘻。

2、Tensorflow 广播举例

Tensorflow中的广播机制和numpy是一样的,因此我们给出一些简单的举例:

二维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3),0,0.1))
b = tf.Variable(tf.random_normal((2,1),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[-0.1419442 ,  0.14135399,  0.22752595],
       [ 0.1382471 ,  0.28228047,  0.13102233]], dtype=float32)

三维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,1,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
        [-0.12693939, -0.08069084, -0.15459496,  0.09405404],
        [ 0.09730847,  0.06936138,  0.04050628,  0.15374713]],

       [[-0.02691782, -0.26384184,  0.05825682, -0.07617196],
        [-0.02653179, -0.01997554, -0.06522765,  0.03028341],
        [-0.07577246,  0.03199019,  0.0321    , -0.12571403]]], dtype=float32)

错误示例

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].

到此这篇关于探秘TensorFlow 和 NumPy 的 Broadcasting 机制的文章就介绍到这了,更多相关TensorFlow 和NumPy 的Broadcasting 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
详细解读Python的web.py框架下的application.py模块
May 02 Python
浅谈django中的认证与登录
Oct 31 Python
Numpy中stack(),hstack(),vstack()函数用法介绍及实例
Jan 09 Python
python的xpath获取div标签内html内容,实现innerhtml功能的方法
Jan 02 Python
OpenCV+face++实现实时人脸识别解锁功能
Aug 28 Python
python 基于dlib库的人脸检测的实现
Nov 08 Python
python实现程序重启和系统重启方式
Apr 16 Python
Python unittest如何生成HTMLTestRunner模块
Sep 08 Python
浅析Python的命名空间与作用域
Nov 25 Python
python中slice参数过长的处理方法及实例
Dec 15 Python
详解python3类型注释annotations实用案例
Jan 20 Python
利用python实现汉诺塔游戏
Mar 01 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 #Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 #Python
python实现俄罗斯方块游戏(改进版)
Mar 13 #Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 #Python
Tensorflow中的dropout的使用方法
Mar 13 #Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
You might like
牡丹941资料
2021/03/01 无线电
回答PHPCHINA上的几个问题:URL映射
2007/02/14 PHP
在字符串指定位置插入一段字符串的php代码
2010/02/16 PHP
php获取qq用户昵称和在线状态(实例分析)
2013/10/27 PHP
PHP数组函数array_multisort()用法实例分析
2016/04/02 PHP
PHP中使用OpenSSL生成证书及加密解密
2017/02/05 PHP
php基于自定义函数记录log日志方法
2017/07/21 PHP
增强的 JavaScript 的 trim 函数的代码
2007/08/13 Javascript
JavaScript入门教程(9) Document文档对象
2009/01/31 Javascript
jquery使用append(content)方法注意事项分享
2014/01/06 Javascript
div失去焦点事件实现思路
2014/04/22 Javascript
require.js深入了解 require.js特性介绍
2014/09/04 Javascript
JS+CSS实现的竖向简洁折叠菜单效果代码
2015/10/22 Javascript
jQuery+AJAX实现遮罩层登录验证界面(附源码)
2020/09/13 Javascript
jquery简单倒计时实现方法
2015/12/18 Javascript
JS实现DIV高度自适应窗口示例
2017/02/16 Javascript
JS如何获取地址栏的参数实例讲解
2018/10/06 Javascript
如何实现iframe父子传参通信
2020/02/05 Javascript
在pycharm中开发vue的方法步骤
2020/03/04 Javascript
python多进程控制学习小结
2018/10/31 Python
python Tkinter版学生管理系统
2019/02/20 Python
PyInstaller将Python文件打包为exe后如何反编译(破解源码)以及防止反编译
2020/04/15 Python
Python内置函数及功能简介汇总
2020/10/13 Python
Opencv python 图片生成视频的方法示例
2020/11/18 Python
CSS3 实现飘动的云朵动画
2020/12/01 HTML / CSS
Html5无刷新修改browser Url的方法
2014/01/15 HTML / CSS
安全的后院和健身蹦床:JumpSport
2019/07/15 全球购物
小学毕业感言50字
2014/02/16 职场文书
槐乡的孩子教学反思
2014/04/27 职场文书
班子成员四风问题自我剖析材料
2014/09/29 职场文书
饭店服务员岗位职责
2015/02/09 职场文书
综合办公室主任岗位职责
2015/04/01 职场文书
2015年幼儿园班主任工作总结
2015/05/12 职场文书
学风建设主题班会
2015/08/17 职场文书
党章党规党纪学习心得体会
2016/01/14 职场文书
redis 限制内存使用大小的实现
2021/05/08 Redis