TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法


Posted in Python onApril 19, 2020

在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

除去name参数用以指定该操作的name,与方法有关的一共两个参数:

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上

具体的执行流程大概分为两步:

第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

softmax的公式是:TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法

其中TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

理论讲完了,上代码

import tensorflow as tf
 
#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
 
with tf.Session() as sess:
  softmax=sess.run(y)
  c_e = sess.run(cross_entropy)
  c_e2 = sess.run(cross_entropy2)
  print("step1:softmax result=")
  print(softmax)
  print("step2:cross_entropy result=")
  print(c_e)
  print("Function(softmax_cross_entropy_with_logits) result=")
  print(c_e2)

输出结果是:

step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228

最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的

到此这篇关于TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法的文章就介绍到这了,更多相关TensorFlow tf.nn.softmax_cross_entropy_with_logits内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python开发WebService系列教程之REST,web.py,eurasia,Django
Jun 30 Python
深入讲解Python中的迭代器和生成器
Oct 26 Python
Python操作SQLite数据库的方法详解【导入,创建,游标,增删改查等】
Jul 11 Python
Scrapy框架CrawlSpiders的介绍以及使用详解
Nov 29 Python
人脸识别经典算法一 特征脸方法(Eigenface)
Mar 13 Python
Python 循环语句之 while,for语句详解
Apr 23 Python
Python全排列操作实例分析
Jul 24 Python
对pandas中两种数据类型Series和DataFrame的区别详解
Nov 12 Python
利用Python+阿里云实现DDNS动态域名解析的方法
Apr 01 Python
python用Tkinter做自己的中文代码编辑器
Sep 07 Python
Python 必须了解的5种高级特征
Sep 10 Python
Python3+SQLAlchemy+Sqlite3实现ORM教程
Feb 16 Python
tensorflow中tf.reduce_mean函数的使用
Apr 19 #Python
TensorFlow打印输出tensor的值
Apr 19 #Python
numpy库reshape用法详解
Apr 19 #Python
tensorflow常用函数API介绍
Apr 19 #Python
TensorFlow的reshape操作 tf.reshape的实现
Apr 19 #Python
pip安装tensorflow的坑的解决
Apr 19 #Python
查看已安装tensorflow版本的方法示例
Apr 19 #Python
You might like
php 缓存函数代码
2008/08/27 PHP
PHP获取搜索引擎关键字来源的函数(支持百度和谷歌等搜索引擎)
2012/10/03 PHP
PHP实现二维数组按某列进行排序的方法
2016/11/18 PHP
PHP实现求解最长公共子串问题的方法
2017/11/17 PHP
jQuery选择头像并实时显示的代码
2010/06/27 Javascript
用显卡加速,轻松把笔记本打造成取暖器的办法!
2013/04/17 Javascript
如何获取网站icon有哪些可行的方法
2014/06/05 Javascript
百度判断手机终端并自动跳转js代码及使用实例
2014/06/11 Javascript
JS使用oumousemove和oumouseout动态改变图片显示的方法
2015/03/31 Javascript
javascript实现table表格隔行变色的方法
2015/05/13 Javascript
JS+JSP通过img标签调用实现静态页面访问次数统计的方法
2015/12/14 Javascript
jquery对Json的各种遍历方法总结(必看篇)
2016/09/29 Javascript
javascript中apply/call和bind的使用
2017/02/15 Javascript
angular 基于ng-messages的表单验证实例
2017/05/04 Javascript
简单实现js进度条加载效果
2020/03/25 Javascript
详解vue-cli中模拟数据的两种方法
2018/07/03 Javascript
对vuejs的v-for遍历、v-bind动态改变值、v-if进行判断的实例讲解
2018/08/27 Javascript
vue基于element的区间选择组件
2018/09/07 Javascript
使用puppeteer爬取网站并抓出404无效链接
2018/12/20 Javascript
浅析微信小程序自定义日历组件及flex布局最后一行对齐问题
2020/10/29 Javascript
[54:08]LGD女子刀塔学院 DOTA2炼金术士教学
2014/01/09 DOTA
python中安装模块包版本冲突问题的解决
2017/05/02 Python
Pycharm学习教程(7)虚拟机VM的配置教程
2017/05/04 Python
深入浅析Python的类
2018/06/22 Python
python使用pdfminer解析pdf文件的方法示例
2018/12/20 Python
Python 正则表达式爬虫使用案例解析
2019/09/23 Python
Python爬虫库requests获取响应内容、响应状态码、响应头
2020/01/25 Python
Python中pyecharts安装及安装失败的解决方法
2020/02/18 Python
解决jupyter notebook 出现In[*]的问题
2020/04/13 Python
使用python批量修改XML文件中图像的depth值
2020/07/22 Python
环境保护标语
2014/06/20 职场文书
处级领导干部四风问题自我剖析材料
2014/09/29 职场文书
化妆品促销活动总结
2015/05/07 职场文书
MySQL深度分页(千万级数据量如何快速分页)
2021/07/25 MySQL
python实现层次聚类的方法
2021/11/01 Python
大脑的记忆过程在做数据压缩,不同图形也有共同的记忆格式
2022/04/29 数码科技