关于tensorflow softmax函数用法解析


Posted in Python onJune 30, 2020

如下所示:

def softmax(logits, axis=None, name=None, dim=None):
 """Computes softmax activations.
 This function performs the equivalent of
  softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
 Args:
 logits: A non-empty `Tensor`. Must be one of the following types: `half`,
  `float32`, `float64`.
 axis: The dimension softmax would be performed on. The default is -1 which
  indicates the last dimension.
 name: A name for the operation (optional).
 dim: Deprecated alias for `axis`.
 Returns:
 A `Tensor`. Has the same type and shape as `logits`.
 Raises:
 InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
  dimension of `logits`.
 """
 axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
 if axis is None:
 axis = -1
 return _softmax(logits, gen_nn_ops.softmax, axis, name)

softmax函数的返回结果和输入的tensor有相同的shape,既然没有改变tensor的形状,那么softmax究竟对tensor做了什么?

答案就是softmax会以某一个轴的下标为索引,对这一轴上其他维度的值进行 激活 + 归一化处理

一般来说,这个索引轴都是表示类别的那个维度(tf.nn.softmax中默认为axis=-1,也就是最后一个维度)

举例:

def softmax(X, theta = 1.0, axis = None):
 """
 Compute the softmax of each element along an axis of X.
 Parameters
 ----------
 X: ND-Array. Probably should be floats.
 theta (optional): float parameter, used as a multiplier
  prior to exponentiation. Default = 1.0
 axis (optional): axis to compute values along. Default is the
  first non-singleton axis.
 Returns an array the same size as X. The result will sum to 1
 along the specified axis.
 """
 
 # make X at least 2d
 y = np.atleast_2d(X)
 
 # find axis
 if axis is None:
  axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)
 
 # multiply y against the theta parameter,
 y = y * float(theta)
 
 # subtract the max for numerical stability
 y = y - np.expand_dims(np.max(y, axis = axis), axis)
 
 # exponentiate y
 y = np.exp(y)
 
 # take the sum along the specified axis
 ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)
 
 # finally: divide elementwise
 p = y / ax_sum
 
 # flatten if X was 1D
 if len(X.shape) == 1: p = p.flatten()
 
 return p
c = np.random.randn(2,3)
print(c)
# 假设第0维是类别,一共有里两种类别
cc = softmax(c,axis=0)
# 假设最后一维是类别,一共有3种类别
ccc = softmax(c,axis=-1)
print(cc)
print(ccc)

结果:

c:
[[-1.30022268 0.59127472 1.21384177]
 [ 0.1981082 -0.83686108 -1.54785864]]
cc:
[[0.1826746 0.80661068 0.94057075]
 [0.8173254 0.19338932 0.05942925]]
ccc:
[[0.0500392 0.33172426 0.61823654]
 [0.65371718 0.23222472 0.1140581 ]]

可以看到,对axis=0的轴做softmax时,输出结果在axis=0轴上和为1(eg: 0.1826746+0.8173254),同理在axis=1轴上做的话结果的axis=1轴和也为1(eg: 0.0500392+0.33172426+0.61823654)。

这些值是怎么得到的呢?

以cc为例(沿着axis=0做softmax):

关于tensorflow softmax函数用法解析

以ccc为例(沿着axis=1做softmax):

关于tensorflow softmax函数用法解析

知道了计算方法,现在我们再来讨论一下这些值的实际意义:

cc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.1826746

cc[1,0]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.8173254

ccc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.0500392

ccc[0,1]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.33172426

ccc[0,2]实际上表示这样一种概率: P( label = 2 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.61823654

将他们扩展到更多维的情况:假设c是一个[batch_size , timesteps, categories]的三维tensor

output = tf.nn.softmax(c,axis=-1)

那么 output[1, 2, 3] 则表示 P(label =3 | value = c[1,2] )

以上这篇关于tensorflow softmax函数用法解析就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作ssh实现服务器日志下载的方法
Jun 03 Python
Python编程之列表操作实例详解【创建、使用、更新、删除】
Jul 22 Python
Python实现截取PDF文件中的几页代码实例
Mar 11 Python
python set内置函数的具体使用
Jul 02 Python
python实现beta分布概率密度函数的方法
Jul 08 Python
python实现切割url得到域名、协议、主机名等各个字段的例子
Jul 25 Python
python实现京东订单推送到测试环境,提供便利操作示例
Aug 09 Python
Python requests获取网页常用方法解析
Feb 20 Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 Python
Pycharm中如何关掉python console
Oct 27 Python
用基于python的appium爬取b站直播消费记录
Apr 17 Python
详解pytorch创建tensor函数
Mar 22 Python
基于tensorflow for循环 while循环案例
Jun 30 #Python
解析Tensorflow之MNIST的使用
Jun 30 #Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 #Python
Python requests模块安装及使用教程图解
Jun 30 #Python
在Tensorflow中实现leakyRelu操作详解(高效)
Jun 30 #Python
TensorFlow-gpu和opencv安装详细教程
Jun 30 #Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 #Python
You might like
使用 MySQL 开始 PHP 会话
2006/12/21 PHP
在PHP中养成7个面向对象的好习惯
2010/07/17 PHP
第五章 php数组操作
2011/12/30 PHP
php中get_meta_tags()、CURL与user-agent用法分析
2014/12/16 PHP
JavaScript 异步调用框架 (Part 5 - 链式实现)
2009/08/04 Javascript
JS多物体 任意值 链式 缓冲运动
2012/08/10 Javascript
将json当数据库一样操作的javascript lib
2013/10/28 Javascript
网页防止tab键的使用快速解决方法
2013/11/07 Javascript
Vue 固定头 固定列 点击表头可排序的表格组件
2016/11/25 Javascript
javascript数组去重常用方法实例分析
2017/04/11 Javascript
js实现本地时间同步功能
2017/08/26 Javascript
Angular将填入表单的数据渲染到表格的方法
2017/09/22 Javascript
jquery ajaxfileuplod 上传文件 essyui laoding 效果【防止重复上传文件】
2018/05/26 jQuery
详解vue.js下引入百度地图jsApi的两种方法
2018/07/27 Javascript
原生js封装的ajax方法示例
2018/08/02 Javascript
微信小程序内拖动图片实现移动、放大、旋转的方法
2018/09/04 Javascript
JS实现手写 forEach算法示例
2020/04/29 Javascript
[01:32]DOTA2次级联赛——首支职业女子战队选拔赛全记录
2014/10/23 DOTA
python在windows命令行下输出彩色文字的方法
2015/03/19 Python
python3中的md5加密实例
2018/05/29 Python
python一行sql太长折成多行并且有多个参数的方法
2018/07/19 Python
Python Learning 列表的更多操作及示例代码
2018/08/22 Python
对python实时得到鼠标位置的示例讲解
2018/10/14 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
2020/06/01 Python
Python爬虫UA伪装爬取的实例讲解
2021/02/19 Python
详解使用canvas保存网页为pdf文件支持跨域
2018/11/23 HTML / CSS
AmazeUI折叠式卡片布局,整合内容列表、表格组件实现
2020/08/20 HTML / CSS
电子商务专业个人的自我评价
2013/12/19 职场文书
上班早退检讨书
2014/01/09 职场文书
新闻编辑求职信
2014/07/13 职场文书
2014年客服工作总结范文
2014/11/13 职场文书
2016年大学生暑假爱心支教活动策划书
2015/11/26 职场文书
大学生党课心得体会
2016/01/07 职场文书
mysql部分操作
2021/04/05 MySQL
Android 界面一键变灰 深色主题工具类
2022/04/28 Java/Android
Windows Server 2012 R2 磁盘分区教程
2022/04/29 Servers