Keras中的多分类损失函数用法categorical_crossentropy


Posted in Python onJune 11, 2020

from keras.utils.np_utils import to_categorical

注意:当使用categorical_crossentropy损失函数时,你的标签应为多类模式,例如如果你有10个类别,每一个样本的标签应该是一个10维的向量,该向量在对应有值的索引位置为1其余为0。

可以使用这个方法进行转换:

from keras.utils.np_utils import to_categorical
categorical_labels = to_categorical(int_labels, num_classes=None)

以mnist数据集为例:

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

...
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=2)

补充知识:Keras中损失函数binary_crossentropy和categorical_crossentropy产生不同结果的分析

问题

在使用keras做对心电信号分类的项目中发现一个问题,这个问题起源于我的一个使用错误:

binary_crossentropy 二进制交叉熵用于二分类问题中,categorical_crossentropy分类交叉熵适用于多分类问题中,我的心电分类是一个多分类问题,但是我起初使用了二进制交叉熵,代码如下所示:

sgd = SGD(lr=0.003, decay=0, momentum=0.7, nesterov=False)
model.compile(loss='categorical_crossentropy',
  optimizer='sgd',metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_test,Y_test),batch_size=16, epochs=20)
score = model.evaluate(X_test, Y_test, batch_size=16)

注意:我的CNN网络模型在最后输入层正确使用了应该用于多分类问题的softmax激活函数

后来我在另一个残差网络模型中对同类数据进行相同的分类问题中,正确使用了分类交叉熵,令人奇怪的是残差模型的效果远弱于普通卷积神经网络,这一点是不符合常理的,经过多次修改分析终于发现可能是损失函数的问题,因此我使用二进制交叉熵在残差网络中,终于取得了优于普通卷积神经网络的效果。

因此可以断定问题就出在所使用的损失函数身上

原理

本人也只是个只会使用框架的调参侠,对于一些原理也是一知半解,经过了学习才大致明白,将一些原理记录如下:

要搞明白分类熵和二进制交叉熵先要从二者适用的激活函数说起

激活函数

sigmoid, softmax主要用于神经网络输出层的输出。

softmax函数

Keras中的多分类损失函数用法categorical_crossentropy

softmax可以看作是Sigmoid的一般情况,用于多分类问题。

Softmax函数将K维的实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于 (0,1) 之间。常用于多分类问题。

sigmoid函数

Keras中的多分类损失函数用法categorical_crossentropy

Sigmoid 将一个实数映射到 (0,1) 的区间,可以用来做二分类。Sigmoid 在特征相差比较复杂或是相差不是特别大时效果比较好。Sigmoid不适合用在神经网络的中间层,因为对于深层网络,sigmoid 函数反向传播时,很容易就会出现梯度消失的情况(在 sigmoid 接近饱和区时,变换太缓慢,导数趋于 0,这种情况会造成信息丢失),从而无法完成深层网络的训练。所以Sigmoid主要用于对神经网络输出层的激活。

分析

所以说多分类问题是要softmax激活函数配合分类交叉熵函数使用,而二分类问题要使用sigmoid激活函数配合二进制交叉熵函数适用,但是如果在多分类问题中使用了二进制交叉熵函数最后的模型分类效果会虚高,即比模型本身真实的分类效果好。

所以就会出现我遇到的情况,这里引用了论坛一位大佬的样例:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # WRONG way

model.fit(x_train, y_train,
   batch_size=batch_size,
   epochs=2, # only 2 epochs, for demonstration purposes
   verbose=1,
   validation_data=(x_test, y_test))

# Keras reported accuracy:
score = model.evaluate(x_test, y_test, verbose=0) 
score[1]
# 0.9975801164627075

# Actual accuracy calculated manually:
import numpy as np
y_pred = model.predict(x_test)
acc = sum([np.argmax(y_test[i])==np.argmax(y_pred[i]) for i in range(10000)])/10000
acc
# 0.98780000000000001

score[1]==acc
# False

样例中模型在评估中得到的准确度高于实际测算得到的准确度,网上给出的原因是Keras没有定义一个准确的度量,但有几个不同的,比如binary_accuracy和categorical_accuracy,当你使用binary_crossentropy时keras默认在评估过程中使用了binary_accuracy,但是针对你的分类要求,应当采用的是categorical_accuracy,所以就造成了这个问题(其中的具体原理我也没去看源码详细了解)

解决

所以问题最后的解决方法就是:

对于多分类问题,要么采用

from keras.metrics import categorical_accuracy
model.compile(loss='binary_crossentropy', 
 optimizer='adam', metrics=[categorical_accuracy])

要么采用

model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])

以上这篇Keras中的多分类损失函数用法categorical_crossentropy就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的CURL PycURL使用例子
Jun 01 Python
Python出现segfault错误解决方法
Apr 16 Python
python用reduce和map把字符串转为数字的方法
Dec 19 Python
python解析基于xml格式的日志文件
Feb 25 Python
python3中set(集合)的语法总结分享
Mar 24 Python
详解Python list 与 NumPy.ndarry 切片之间的对比
Jul 24 Python
关于Python中空格字符串处理的技巧总结
Aug 10 Python
Django模板语言 Tags使用详解
Sep 09 Python
python的dict判断key是否存在的方法
Dec 09 Python
Biblibili视频投稿接口分析并以Python实现自动投稿功能
Feb 05 Python
字典算法实现及操作 --python(实用)
Mar 31 Python
django 认证类配置实现
Nov 11 Python
Python 列表中的修改、添加和删除元素的实现
Jun 11 #Python
python中什么是面向对象
Jun 11 #Python
python实现凯撒密码、凯撒加解密算法
Jun 11 #Python
python新手学习可变和不可变对象
Jun 11 #Python
基于Keras 循环训练模型跑数据时内存泄漏的解决方式
Jun 11 #Python
什么是python的id函数
Jun 11 #Python
Keras:Unet网络实现多类语义分割方式
Jun 11 #Python
You might like
详解Laravel设置多态关系模型别名的方式
2019/10/17 PHP
google地图的路线实现代码
2009/08/20 Javascript
利用jQuery的$.event.fix函数统一浏览器event事件处理
2009/12/21 Javascript
仅IE6/7/8中innerHTML返回值忽略英文空格的问题
2011/04/07 Javascript
腾讯UED 漂亮的提示信息效果代码
2011/09/12 Javascript
表头固定(利用jquery实现原理介绍)
2012/11/08 Javascript
JavaScript中prototype为对象添加属性的误区介绍
2013/10/15 Javascript
js设置cookie过期及清除浏览器对应名称的cookie
2013/10/24 Javascript
你不需要jQuery(三) 新AJAX方法fetch()
2016/06/14 Javascript
基于jQuery实现一个marquee无缝滚动的插件
2017/03/09 Javascript
Bootstrap学习笔记之进度条、媒体对象实例详解
2017/03/09 Javascript
原JS实现banner图的常用功能
2017/06/12 Javascript
基于vue-video-player自定义播放器的方法
2018/03/21 Javascript
vue中touch和click共存的解决方式
2020/07/28 Javascript
浅析python中的分片与截断序列
2016/08/09 Python
Python实现将一个正整数分解质因数的方法分析
2017/12/14 Python
Python爬虫获取整个站点中的所有外部链接代码示例
2017/12/26 Python
Tensorflow实现AlexNet卷积神经网络及运算时间评测
2018/05/24 Python
简单了解python的break、continue、pass
2019/07/08 Python
python dumps和loads区别详解
2020/02/04 Python
python3中使用__slots__限定实例属性操作分析
2020/02/14 Python
matplotlib 曲线图 和 折线图 plt.plot()实例
2020/04/17 Python
解决keras,val_categorical_accuracy:,0.0000e+00问题
2020/07/02 Python
Python特殊属性property原理及使用方法解析
2020/10/09 Python
Html5新增标签有哪些
2017/04/13 HTML / CSS
Alba Moda德国网上商店:意大利时尚女装销售
2016/11/14 全球购物
自荐信的五个重要部分
2013/10/29 职场文书
小学教师寄语大全
2014/04/03 职场文书
2014物价局民主生活会对照检查材料思想汇报
2014/09/24 职场文书
国际政治学专业推荐信
2014/09/26 职场文书
2016年基层党组织公开承诺书
2016/03/25 职场文书
go语言中http超时引发的事故解决
2021/06/02 Golang
浅谈JS的原型和原型链
2021/06/04 Javascript
js中Map和Set的用法及区别实例详解
2022/02/15 Javascript
python数字图像处理实现图像的形变与缩放
2022/06/28 Python
vue递归实现树形组件
2022/07/15 Vue.js