关于tensorflow softmax函数用法解析


Posted in Python onJune 30, 2020

如下所示:

def softmax(logits, axis=None, name=None, dim=None):
 """Computes softmax activations.
 This function performs the equivalent of
  softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
 Args:
 logits: A non-empty `Tensor`. Must be one of the following types: `half`,
  `float32`, `float64`.
 axis: The dimension softmax would be performed on. The default is -1 which
  indicates the last dimension.
 name: A name for the operation (optional).
 dim: Deprecated alias for `axis`.
 Returns:
 A `Tensor`. Has the same type and shape as `logits`.
 Raises:
 InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
  dimension of `logits`.
 """
 axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
 if axis is None:
 axis = -1
 return _softmax(logits, gen_nn_ops.softmax, axis, name)

softmax函数的返回结果和输入的tensor有相同的shape,既然没有改变tensor的形状,那么softmax究竟对tensor做了什么?

答案就是softmax会以某一个轴的下标为索引,对这一轴上其他维度的值进行 激活 + 归一化处理

一般来说,这个索引轴都是表示类别的那个维度(tf.nn.softmax中默认为axis=-1,也就是最后一个维度)

举例:

def softmax(X, theta = 1.0, axis = None):
 """
 Compute the softmax of each element along an axis of X.
 Parameters
 ----------
 X: ND-Array. Probably should be floats.
 theta (optional): float parameter, used as a multiplier
  prior to exponentiation. Default = 1.0
 axis (optional): axis to compute values along. Default is the
  first non-singleton axis.
 Returns an array the same size as X. The result will sum to 1
 along the specified axis.
 """
 
 # make X at least 2d
 y = np.atleast_2d(X)
 
 # find axis
 if axis is None:
  axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)
 
 # multiply y against the theta parameter,
 y = y * float(theta)
 
 # subtract the max for numerical stability
 y = y - np.expand_dims(np.max(y, axis = axis), axis)
 
 # exponentiate y
 y = np.exp(y)
 
 # take the sum along the specified axis
 ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)
 
 # finally: divide elementwise
 p = y / ax_sum
 
 # flatten if X was 1D
 if len(X.shape) == 1: p = p.flatten()
 
 return p
c = np.random.randn(2,3)
print(c)
# 假设第0维是类别,一共有里两种类别
cc = softmax(c,axis=0)
# 假设最后一维是类别,一共有3种类别
ccc = softmax(c,axis=-1)
print(cc)
print(ccc)

结果:

c:
[[-1.30022268 0.59127472 1.21384177]
 [ 0.1981082 -0.83686108 -1.54785864]]
cc:
[[0.1826746 0.80661068 0.94057075]
 [0.8173254 0.19338932 0.05942925]]
ccc:
[[0.0500392 0.33172426 0.61823654]
 [0.65371718 0.23222472 0.1140581 ]]

可以看到,对axis=0的轴做softmax时,输出结果在axis=0轴上和为1(eg: 0.1826746+0.8173254),同理在axis=1轴上做的话结果的axis=1轴和也为1(eg: 0.0500392+0.33172426+0.61823654)。

这些值是怎么得到的呢?

以cc为例(沿着axis=0做softmax):

关于tensorflow softmax函数用法解析

以ccc为例(沿着axis=1做softmax):

关于tensorflow softmax函数用法解析

知道了计算方法,现在我们再来讨论一下这些值的实际意义:

cc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.1826746

cc[1,0]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.8173254

ccc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.0500392

ccc[0,1]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.33172426

ccc[0,2]实际上表示这样一种概率: P( label = 2 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.61823654

将他们扩展到更多维的情况:假设c是一个[batch_size , timesteps, categories]的三维tensor

output = tf.nn.softmax(c,axis=-1)

那么 output[1, 2, 3] 则表示 P(label =3 | value = c[1,2] )

以上这篇关于tensorflow softmax函数用法解析就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Google开源的Python格式化工具YAPF的安装和使用教程
May 31 Python
Python备份目录及目录下的全部内容的实现方法
Jun 12 Python
浅述python中argsort()函数的实例用法
Mar 30 Python
Linux RedHat下安装Python2.7开发环境
May 20 Python
Python模块搜索路径代码详解
Jan 29 Python
python爬虫框架scrapy实现模拟登录操作示例
Aug 02 Python
Python面向对象基础入门之设置对象属性
Dec 11 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
django自定义模板标签过程解析
Dec 14 Python
python 等差数列末项计算方式
May 03 Python
python实现自动清理重复文件
Aug 24 Python
python制作图形界面的2048游戏, 基于tkinter
Apr 06 Python
基于tensorflow for循环 while循环案例
Jun 30 #Python
解析Tensorflow之MNIST的使用
Jun 30 #Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 #Python
Python requests模块安装及使用教程图解
Jun 30 #Python
在Tensorflow中实现leakyRelu操作详解(高效)
Jun 30 #Python
TensorFlow-gpu和opencv安装详细教程
Jun 30 #Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 #Python
You might like
ThinkPHP3.1查询语言详解
2014/06/19 PHP
php定义一个参数带有默认值的函数实例分析
2015/03/16 PHP
大家须知简单的php性能优化注意点
2016/01/04 PHP
PHP正则替换函数preg_replace()报错:Notice Use of undefined constant的解决方法分析
2017/02/04 PHP
javascript 获取url参数和script标签中获取url参数函数代码
2010/01/22 Javascript
Javascript实现CheckBox的全选与取消全选的代码
2010/07/20 Javascript
JQuery里面的几种选择器 查找满足条件的元素$("#控件ID")
2011/08/23 Javascript
自定义jQuery选项卡插件实例
2013/03/27 Javascript
JavaScript中伪协议 javascript:使用探讨
2014/07/18 Javascript
修改js confirm alert 提示框文字的简单实例
2016/06/10 Javascript
vue.js实现表格合并示例代码
2016/11/30 Javascript
Bootstrap 网格系统布局详解
2017/03/19 Javascript
JavaScript 基础表单验证示例(纯Js实现)
2017/07/20 Javascript
微信小程序 POST请求的实例详解
2017/09/29 Javascript
基于webpack.config.js 参数详解
2018/03/20 Javascript
详解nodejs通过响应回写的方式渲染页面资源
2018/04/07 NodeJs
create-react-app修改为多页面支持的方法
2018/05/17 Javascript
Vue.js 父子组件通信的十种方式
2018/10/30 Javascript
JS实现随机生成10个手机号的方法示例
2018/12/07 Javascript
vue在自定义组件中使用v-model进行数据绑定的方法
2019/03/25 Javascript
vue实现折线图 可按时间查询
2020/08/21 Javascript
JS检测浏览器开发者工具是否打开的方法详解
2020/10/02 Javascript
vue 解决在微信内置浏览器中调用支付宝支付的情况
2020/11/09 Javascript
[43:51]2018DOTA2亚洲邀请赛3月30日 小组赛B组 EG VS Secret
2018/03/31 DOTA
在SAE上部署Python的Django框架的一些问题汇总
2015/05/30 Python
python用装饰器自动注册Tornado路由详解
2017/02/14 Python
详谈Python高阶函数与函数装饰器(推荐)
2017/09/30 Python
Python实现重建二叉树的三种方法详解
2018/06/23 Python
python输出数组中指定元素的所有索引示例
2019/12/06 Python
HTML5表格_动力节点Java学院整理
2017/07/11 HTML / CSS
GUESS Factory加拿大:牛仔裤、服装及配饰
2019/09/20 全球购物
演讲稿格式
2014/04/30 职场文书
党员“四风”问题批评与自我批评思想汇报
2014/10/06 职场文书
海南召开党的群众路线教育实践活动总结大会新闻稿
2014/10/21 职场文书
大雁塔导游词
2015/02/04 职场文书
教师求职自荐信
2015/03/26 职场文书