TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法


Posted in Python onApril 19, 2020

在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

除去name参数用以指定该操作的name,与方法有关的一共两个参数:

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上

具体的执行流程大概分为两步:

第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

softmax的公式是:TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

其中TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

理论讲完了,上代码

import tensorflow as tf
 
#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
 
with tf.Session() as sess:
  softmax=sess.run(y)
  c_e = sess.run(cross_entropy)
  c_e2 = sess.run(cross_entropy2)
  print("step1:softmax result=")
  print(softmax)
  print("step2:cross_entropy result=")
  print(c_e)
  print("Function(softmax_cross_entropy_with_logits) result=")
  print(c_e2)

输出结果是:

step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228

最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的

到此这篇关于TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法的文章就介绍到这了,更多相关TensorFlow tf.nn.softmax_cross_entropy_with_logits内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python使用百度API上传文件到百度网盘代码分享
Nov 08 Python
用Python代码来绘制彭罗斯点阵的教程
Apr 03 Python
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
Python内存管理方式和垃圾回收算法解析
Nov 11 Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 Python
启动Atom并运行python文件的步骤
Nov 09 Python
pygame游戏之旅 python和pygame安装教程
Nov 20 Python
在Tensorflow中查看权重的实现
Jan 24 Python
pycharm无法安装第三方库的问题及解决方法以scrapy为例(图解)
May 09 Python
什么是python类属性
Jun 10 Python
Python实现一个论文下载器的过程
Jan 18 Python
Python绘制分类图的方法
Apr 20 Python
tensorflow中tf.reduce_mean函数的使用
Apr 19 #Python
TensorFlow打印输出tensor的值
Apr 19 #Python
numpy库reshape用法详解
Apr 19 #Python
tensorflow常用函数API介绍
Apr 19 #Python
TensorFlow的reshape操作 tf.reshape的实现
Apr 19 #Python
pip安装tensorflow的坑的解决
Apr 19 #Python
查看已安装tensorflow版本的方法示例
Apr 19 #Python
You might like
新版PHP将向Java靠拢
2006/10/09 PHP
PHP sprintf()函数用例解析
2011/05/18 PHP
PHP测试程序运行时间的类
2012/02/05 PHP
PHP调用VC编写的COM组件实例
2014/03/29 PHP
php 中phar包的使用教程详解
2018/10/26 PHP
URL编码转换,escape() encodeURI() encodeURIComponent()
2006/12/27 Javascript
jQuery 工具函数学习资料
2010/04/29 Javascript
JavaScript判断数组重复内容的两种方法(推荐)
2016/06/06 Javascript
基于css3新属性transform及原生js实现鼠标拖动3d立方体旋转
2016/06/12 Javascript
jqPlot jQuery绘图插件的使用
2016/06/18 Javascript
浅谈JQ中mouseover和mouseenter的区别
2016/09/13 Javascript
JavaScript条件判断_动力节点Java学院整理
2017/06/26 Javascript
React教程之Props验证的具体用法(Props Validation)
2017/09/04 Javascript
js阻止默认右键的下拉菜单方法
2018/01/02 Javascript
vue-router之nuxt动态路由设置的两种方法小结
2018/09/26 Javascript
面试题:react和vue的区别分析
2019/04/08 Javascript
详解Vue 全局变量,局部变量
2019/04/17 Javascript
JavaScript DOM常用操作代码汇总
2020/07/03 Javascript
[02:58]献给西雅图的情书_高清
2014/05/29 DOTA
在Python中使用mongoengine操作MongoDB教程
2015/04/24 Python
python基础知识小结之集合
2015/11/25 Python
python下如何查询CS反恐精英的服务器信息
2017/01/17 Python
python元组和字典的内建函数实例详解
2019/10/22 Python
python线程信号量semaphore使用解析
2019/11/30 Python
python列表推导式入门学习解析
2019/12/02 Python
Python 窗体(tkinter)下拉列表框(Combobox)实例
2020/03/04 Python
澳大利亚男士西服品牌:M.J.Bale
2018/02/06 全球购物
村官学习十八大感想
2014/01/15 职场文书
岗位说明书怎么写
2014/07/30 职场文书
环境保护建议书
2014/08/26 职场文书
教育合作协议范本
2014/10/17 职场文书
学习普通话的体会
2014/11/07 职场文书
现场施工员岗位职责
2015/04/11 职场文书
爱国主义教育基地观后感
2015/06/18 职场文书
组织委员竞选稿
2015/11/21 职场文书
Python绘制散乱的点构成的图的方法
2022/04/21 Python