利用python中的matplotlib打印混淆矩阵实例


Posted in Python onJune 16, 2020

前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。

代码:

import itertools
import matplotlib.pyplot as plt
import numpy as np
 
def plot_confusion_matrix(cm, classes,
       normalize=False,
       title='Confusion matrix',
       cmap=plt.cm.Blues):
 """
 This function prints and plots the confusion matrix.
 Normalization can be applied by setting `normalize=True`.
 """
 if normalize:
  cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  print("Normalized confusion matrix")
 else:
  print('Confusion matrix, without normalization')
 
 print(cm)
 
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 
 fmt = '.2f' if normalize else 'd'
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  plt.text(j, i, format(cm[i, j], fmt),
     horizontalalignment="center",
     color="white" if cm[i, j] > thresh else "black")
 
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()
 # plt.savefig('confusion_matrix',dpi=200)
 
cnf_matrix = np.array([
 [4101, 2, 5, 24, 0],
 [50, 3930, 6, 14, 5],
 [29, 3, 3973, 4, 0],
 [45, 7, 1, 3878, 119],
 [31, 1, 8, 28, 3936],
])
 
class_names = ['Buildings', 'Farmland', 'Greenbelt', 'Wasteland', 'Water']
 
# plt.figure()
# plot_confusion_matrix(cnf_matrix, classes=class_names,
#      title='Confusion matrix, without normalization')
 
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
      title='Normalized confusion matrix')

在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。

补充知识:混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)

原理

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.

混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.

使用混淆矩阵( scikit-learn 和 Tensorflow)

下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.

scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口

skearn.metrics.confusion_matrix(
 y_true, # array, Gound true (correct) target values
 y_pred, # array, Estimated targets as returned by a classifier
 labels=None, # array, List of labels to index the matrix.
 sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)

在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.

按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.

如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.

Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口

tf.confusion_matrix(
 labels, # 1-D Tensor of real labels for the classification task
 predictions, # 1-D Tensor of predictions for a givenclassification
 num_classes=None, # The possible number of labels the classification task can have
 dtype=tf.int32, # Data type of the confusion matrix 
 name=None, # Scope name
 weights=None, # An optional Tensor whose shape matches predictions
)

Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为num_classes 参数值.

tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.

使用示例

#!/usr/bin/env python
# -*- coding: utf8 -*-
"""
Author: klchang
Description: 
A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function
import tensorflow as tf
import sklearn.metrics
 
y_true = [1, 2, 4]
y_pred = [2, 2, 4]
 
# Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession()
op = tf.confusion_matrix(y_true, y_pred)
op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))
# Execute the graph
print ("confusion matrix in tensorflow: ")
print ("1. default: \n", op.eval())
print ("2. customed: \n", sess.run(op2))
sess.close()
 
# Use sklearn.metrics.confusion_matrix function
print ("\nconfusion matrix in scikit-learn: ")
print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred))
print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))

以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之重回函数
Oct 10 Python
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
利用python3随机生成中文字符的实现方法
Nov 24 Python
python实现堆和索引堆的代码示例
Mar 19 Python
python递归下载文件夹下所有文件
Aug 31 Python
python队列原理及实现方法示例
Nov 27 Python
Python模块 _winreg操作注册表
Feb 05 Python
Python中的xlrd模块使用原理解析
May 21 Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 Python
python中Array和DataFrame相互转换的实例讲解
Feb 03 Python
Python3.9.1中使用match方法详解
Feb 08 Python
python中if和elif的区别介绍
Nov 07 Python
Python SMTP配置参数并发送邮件
Jun 16 #Python
基于matplotlib中ion()和ioff()的使用详解
Jun 16 #Python
Python数据相关系数矩阵和热力图轻松实现教程
Jun 16 #Python
matplotlib.pyplot.matshow 矩阵可视化实例
Jun 16 #Python
使用python matploblib库绘制准确率,损失率折线图
Jun 16 #Python
为什么称python为胶水语言
Jun 16 #Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 #Python
You might like
Linux fgetcsv取得的数组元素为空字符串的解决方法
2011/11/25 PHP
深入理解require与require_once与include以及include_once的区别
2013/06/05 PHP
PHP查询并删除数据库多列重复数据的方法(利用数组函数实现)
2016/02/23 PHP
javascript 设计模式之单体模式 面向对象学习基础
2010/04/18 Javascript
nodejs入门详解(多篇文章结合)
2012/03/07 NodeJs
Jquery倒数计时按钮setTimeout的实例代码
2013/07/04 Javascript
jQuery可见性过滤选择器用法示例
2016/09/09 Javascript
分分钟学会vue中vuex的应用(入门教程)
2017/09/14 Javascript
快速解决处理后台返回json数据格式的问题
2018/08/07 Javascript
layui从数据库中获取复选框的值并默认选中方法
2018/08/15 Javascript
node.js调用C++函数的方法示例
2018/09/21 Javascript
vue 重塑数组之修改数组指定index的值操作
2020/08/09 Javascript
Vue+Element UI 树形控件整合下拉功能菜单(tree + dropdown +input)
2020/08/28 Javascript
python实现计算资源图标crc值的方法
2014/10/05 Python
利用Python的Django框架生成PDF文件的教程
2015/07/22 Python
浅谈python迭代器
2017/11/08 Python
python opencv旋转图像(保持图像不被裁减)
2018/07/26 Python
django celery redis使用具体实践
2019/04/08 Python
pyautogui自动化控制鼠标和键盘操作的步骤
2020/04/01 Python
详解使用CSS3的@media来编写响应式的页面
2017/11/01 HTML / CSS
Otel.com:折扣酒店预订
2017/08/24 全球购物
人力资源专业推荐信
2013/11/29 职场文书
领导批评与自我批评范文
2014/10/16 职场文书
医药公司采购员岗位职责
2015/04/03 职场文书
企业党员岗位承诺书
2015/04/27 职场文书
2015年惩防体系建设工作总结
2015/05/22 职场文书
第一节英语课开场白
2015/06/01 职场文书
2015年音乐教学工作总结
2015/07/22 职场文书
幼儿园毕业典礼园长致辞
2015/07/29 职场文书
三十年再续同学情倡议书
2019/11/27 职场文书
MySQL5.7并行复制原理及实现
2021/06/03 MySQL
利用python实时刷新基金估值(摸鱼小工具)
2021/09/15 Python
yyds什么意思?90后已经听不懂00后讲话了……
2022/02/03 杂记
豆瓣2021评分最高动画剧集-豆瓣评分最高的动画剧集2021
2022/03/18 日漫
分享python函数常见关键字
2022/04/26 Python
python中pycryto实现数据加密
2022/04/29 Python