TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法


Posted in Python onApril 19, 2020

在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

除去name参数用以指定该操作的name,与方法有关的一共两个参数:

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上

具体的执行流程大概分为两步:

第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

softmax的公式是:TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

其中TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

理论讲完了,上代码

import tensorflow as tf
 
#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
 
with tf.Session() as sess:
  softmax=sess.run(y)
  c_e = sess.run(cross_entropy)
  c_e2 = sess.run(cross_entropy2)
  print("step1:softmax result=")
  print(softmax)
  print("step2:cross_entropy result=")
  print(c_e)
  print("Function(softmax_cross_entropy_with_logits) result=")
  print(c_e2)

输出结果是:

step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228

最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的

到此这篇关于TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法的文章就介绍到这了,更多相关TensorFlow tf.nn.softmax_cross_entropy_with_logits内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python常用随机数与随机字符串方法实例
Apr 09 Python
使用Python对IP进行转换的一些操作技巧小结
Nov 09 Python
Python利用递归和walk()遍历目录文件的方法示例
Jul 14 Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 Python
django输出html内容的实例
May 27 Python
tensorflow 打印内存中的变量方法
Jul 30 Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 Python
python3 使用Opencv打开USB摄像头,配置1080P分辨率的操作
Dec 11 Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 Python
python3连接mysql获取ansible动态inventory脚本
Jan 19 Python
python烟花效果的代码实例
Feb 25 Python
Python实现一个论文下载器的过程
Jan 18 Python
tensorflow中tf.reduce_mean函数的使用
Apr 19 #Python
TensorFlow打印输出tensor的值
Apr 19 #Python
numpy库reshape用法详解
Apr 19 #Python
tensorflow常用函数API介绍
Apr 19 #Python
TensorFlow的reshape操作 tf.reshape的实现
Apr 19 #Python
pip安装tensorflow的坑的解决
Apr 19 #Python
查看已安装tensorflow版本的方法示例
Apr 19 #Python
You might like
PHP通用检测函数集合
2006/11/25 PHP
php split汉字
2009/06/05 PHP
PHP里的中文变量说明
2011/07/23 PHP
PHP遍历数组的几种方法
2012/03/22 PHP
php统计时间和内存使用情况示例分享
2014/03/13 PHP
php文件包含目录配置open_basedir的使用与性能详解
2017/04/03 PHP
thinkphp框架类库扩展操作示例
2019/11/26 PHP
使用IE的地址栏来辅助调试Web页脚本
2007/03/08 Javascript
JQuery实现table行折叠效果以JSON做数据源
2014/05/26 Javascript
JS判断客服QQ号在线还是离线状态的方法
2015/01/13 Javascript
JS实现当前页居中分页效果的方法
2015/06/18 Javascript
基于JavaScript实现div层跟随滚动条滑动
2016/01/12 Javascript
javascript实现数组去重的多种方法
2016/03/14 Javascript
浅谈Javascript数组(推荐)
2016/05/17 Javascript
微信小程序表单验证错误提示效果
2017/05/19 Javascript
javascript基础进阶_深入剖析执行环境及作用域链
2017/09/05 Javascript
js精确的加减乘除实例
2017/11/14 Javascript
解决在Bootstrap模糊框中使用WebUploader的问题
2018/03/22 Javascript
vue脚手架搭建过程图解
2018/06/06 Javascript
no-vnc和node.js实现web远程桌面的完整步骤
2019/08/11 Javascript
JointJS JavaScript流程图绘制框架解析
2019/08/15 Javascript
jQuery实现消息弹出框效果
2019/12/10 jQuery
[01:32]DOTA2 2015国际邀请赛中国区预选赛第四日战报
2015/05/29 DOTA
Pandas读取MySQL数据到DataFrame的方法
2018/07/25 Python
通过python将大量文件按修改时间分类的方法
2018/10/17 Python
Python selenium的基本使用方法分析
2019/12/21 Python
python中什么是面向对象
2020/06/11 Python
Python unittest生成测试报告过程解析
2020/09/08 Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
2021/03/04 Python
美国马匹用品和马钉购物网站:State Line Tack
2018/08/05 全球购物
Charlotte Tilbury澳大利亚官网:英国美妆品牌
2018/10/05 全球购物
英语一分钟演讲稿
2014/04/29 职场文书
2014年学生会部门工作总结
2014/11/07 职场文书
2015年大学班长个人工作总结
2015/04/24 职场文书
2015年房产销售工作总结范文
2015/05/22 职场文书
初中教务主任竞聘演讲稿(范文)
2019/08/20 职场文书