Keras中的多分类损失函数用法categorical_crossentropy


Posted in Python onJune 11, 2020

from keras.utils.np_utils import to_categorical

注意:当使用categorical_crossentropy损失函数时,你的标签应为多类模式,例如如果你有10个类别,每一个样本的标签应该是一个10维的向量,该向量在对应有值的索引位置为1其余为0。

可以使用这个方法进行转换:

from keras.utils.np_utils import to_categorical
categorical_labels = to_categorical(int_labels, num_classes=None)

以mnist数据集为例:

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

...
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=2)

补充知识:Keras中损失函数binary_crossentropy和categorical_crossentropy产生不同结果的分析

问题

在使用keras做对心电信号分类的项目中发现一个问题,这个问题起源于我的一个使用错误:

binary_crossentropy 二进制交叉熵用于二分类问题中,categorical_crossentropy分类交叉熵适用于多分类问题中,我的心电分类是一个多分类问题,但是我起初使用了二进制交叉熵,代码如下所示:

sgd = SGD(lr=0.003, decay=0, momentum=0.7, nesterov=False)
model.compile(loss='categorical_crossentropy',
  optimizer='sgd',metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_test,Y_test),batch_size=16, epochs=20)
score = model.evaluate(X_test, Y_test, batch_size=16)

注意:我的CNN网络模型在最后输入层正确使用了应该用于多分类问题的softmax激活函数

后来我在另一个残差网络模型中对同类数据进行相同的分类问题中,正确使用了分类交叉熵,令人奇怪的是残差模型的效果远弱于普通卷积神经网络,这一点是不符合常理的,经过多次修改分析终于发现可能是损失函数的问题,因此我使用二进制交叉熵在残差网络中,终于取得了优于普通卷积神经网络的效果。

因此可以断定问题就出在所使用的损失函数身上

原理

本人也只是个只会使用框架的调参侠,对于一些原理也是一知半解,经过了学习才大致明白,将一些原理记录如下:

要搞明白分类熵和二进制交叉熵先要从二者适用的激活函数说起

激活函数

sigmoid, softmax主要用于神经网络输出层的输出。

softmax函数

Keras中的多分类损失函数用法categorical_crossentropy

softmax可以看作是Sigmoid的一般情况,用于多分类问题。

Softmax函数将K维的实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于 (0,1) 之间。常用于多分类问题。

sigmoid函数

Keras中的多分类损失函数用法categorical_crossentropy

Sigmoid 将一个实数映射到 (0,1) 的区间,可以用来做二分类。Sigmoid 在特征相差比较复杂或是相差不是特别大时效果比较好。Sigmoid不适合用在神经网络的中间层,因为对于深层网络,sigmoid 函数反向传播时,很容易就会出现梯度消失的情况(在 sigmoid 接近饱和区时,变换太缓慢,导数趋于 0,这种情况会造成信息丢失),从而无法完成深层网络的训练。所以Sigmoid主要用于对神经网络输出层的激活。

分析

所以说多分类问题是要softmax激活函数配合分类交叉熵函数使用,而二分类问题要使用sigmoid激活函数配合二进制交叉熵函数适用,但是如果在多分类问题中使用了二进制交叉熵函数最后的模型分类效果会虚高,即比模型本身真实的分类效果好。

所以就会出现我遇到的情况,这里引用了论坛一位大佬的样例:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # WRONG way

model.fit(x_train, y_train,
   batch_size=batch_size,
   epochs=2, # only 2 epochs, for demonstration purposes
   verbose=1,
   validation_data=(x_test, y_test))

# Keras reported accuracy:
score = model.evaluate(x_test, y_test, verbose=0) 
score[1]
# 0.9975801164627075

# Actual accuracy calculated manually:
import numpy as np
y_pred = model.predict(x_test)
acc = sum([np.argmax(y_test[i])==np.argmax(y_pred[i]) for i in range(10000)])/10000
acc
# 0.98780000000000001

score[1]==acc
# False

样例中模型在评估中得到的准确度高于实际测算得到的准确度,网上给出的原因是Keras没有定义一个准确的度量,但有几个不同的,比如binary_accuracy和categorical_accuracy,当你使用binary_crossentropy时keras默认在评估过程中使用了binary_accuracy,但是针对你的分类要求,应当采用的是categorical_accuracy,所以就造成了这个问题(其中的具体原理我也没去看源码详细了解)

解决

所以问题最后的解决方法就是:

对于多分类问题,要么采用

from keras.metrics import categorical_accuracy
model.compile(loss='binary_crossentropy', 
 optimizer='adam', metrics=[categorical_accuracy])

要么采用

model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])

以上这篇Keras中的多分类损失函数用法categorical_crossentropy就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现爬虫下载漫画示例
Feb 16 Python
python类和继承用法实例
Jul 07 Python
全面解析Python的While循环语句的使用方法
Oct 13 Python
Python 爬虫图片简单实现
Jun 01 Python
Python中表示字符串的三种方法
Sep 06 Python
Python实现矩阵相乘的三种方法小结
Jul 26 Python
python 使用sys.stdin和fileinput读入标准输入的方法
Oct 17 Python
python3 反射的四种基本方法解析
Aug 26 Python
如何在 Django 模板中输出 "{{"
Jan 24 Python
python3 简单实现组合设计模式
Jul 02 Python
为什么说python更适合树莓派编程
Jul 20 Python
Python try except finally资源回收的实现
Jan 25 Python
Python 列表中的修改、添加和删除元素的实现
Jun 11 #Python
python中什么是面向对象
Jun 11 #Python
python实现凯撒密码、凯撒加解密算法
Jun 11 #Python
python新手学习可变和不可变对象
Jun 11 #Python
基于Keras 循环训练模型跑数据时内存泄漏的解决方式
Jun 11 #Python
什么是python的id函数
Jun 11 #Python
Keras:Unet网络实现多类语义分割方式
Jun 11 #Python
You might like
PHP命令Command模式用法实例分析
2018/08/08 PHP
实例介绍PHP删除数组中的重复元素
2019/03/03 PHP
PHP+redis实现微博的推模型案例分析
2019/07/10 PHP
jQuery jqgrid 对含特殊字符json 数据的 Java 处理方法
2011/01/01 Javascript
JS 精确统计网站访问量的实例代码
2013/07/05 Javascript
js解析json读取List中的实体对象示例
2014/03/11 Javascript
运用jQuery定时器的原理实现banner图片切换
2014/10/22 Javascript
JavaScript实现的字符串replaceAll函数代码分享
2015/04/02 Javascript
AngularJS开发教程之控制器之间的通信方法分析
2016/12/25 Javascript
jQuery回调方法使用示例
2017/06/26 jQuery
vue使用better-scroll实现下拉刷新、上拉加载
2018/11/23 Javascript
Java Varargs 可变参数用法详解
2020/01/28 Javascript
微信小程序如何加载数据库真实数据的实现
2020/03/04 Javascript
[44:51]2018DOTA2亚洲邀请赛 4.4 淘汰赛 VP vs Liquid 第二场
2018/04/05 DOTA
[01:11:46]DOTA2-DPC中国联赛 正赛 iG vs Magma BO3 第一场 2月23日
2021/03/11 DOTA
Python时间戳与时间字符串互相转换实例代码
2013/11/28 Python
python常规方法实现数组的全排列
2015/03/17 Python
如何准确判断请求是搜索引擎爬虫(蜘蛛)发出的请求
2015/10/13 Python
浅谈编码,解码,乱码的问题
2016/12/30 Python
Python 数据结构之堆栈实例代码
2017/01/22 Python
利用Tkinter和matplotlib两种方式画饼状图的实例
2017/11/06 Python
使用Python+Splinter自动刷新抢12306火车票
2018/01/03 Python
matlab中实现矩阵删除一行或一列的方法
2018/04/04 Python
对python中raw_input()和input()的用法详解
2018/04/22 Python
python语音识别实践之百度语音API
2018/08/30 Python
python爬虫获取小区经纬度以及结构化地址
2018/12/30 Python
pyqt5实现俄罗斯方块游戏
2019/01/11 Python
python实现上传文件到linux指定目录的方法
2020/01/03 Python
在脚本中单独使用django的ORM模型详解
2020/04/01 Python
python使用布隆过滤器的实现示例
2020/08/20 Python
有机婴儿毛毯和衣服:Monica + Andy
2020/03/01 全球购物
外企财务年会演讲稿
2014/01/03 职场文书
结婚当天新郎保证书
2015/05/08 职场文书
单位计划生育责任书
2015/05/09 职场文书
六一儿童节致辞稿(3篇)
2019/07/11 职场文书
MySQL中distinct和count(*)的使用方法比较
2021/05/26 MySQL