Python数据相关系数矩阵和热力图轻松实现教程


Posted in Python onJune 16, 2020

对其中的参数进行解释

plt.subplots(figsize=(9, 9))设置画面大小,会使得整个画面等比例放大的

sns.heapmap()这个当然是用来生成热力图的啦

df是DataFrame, pandas的这个类还是很常用的啦~

df.corr()就是得到这个dataframe的相关系数矩阵

把这个矩阵直接丢给sns.heapmap中做参数就好啦

sns.heapmap中annot=True,意思是显式热力图上的数值大小。

sns.heapmap中square=True,意思是将图变成一个正方形,默认是一个矩形

sns.heapmap中cmap="Blues"是一种模式,就是图颜色配置方案啦,我很喜欢这一款的。

sns.heapmap中vmax是显示最大值

import seaborn as sns
import matplotlib.pyplot as plt
def test(df):
 dfData = df.corr()
 plt.subplots(figsize=(9, 9)) # 设置画面大小
 sns.heatmap(dfData, annot=True, vmax=1, square=True, cmap="Blues")
 plt.savefig('./BluesStateRelation.png')
 plt.show()

补充知识:python混淆矩阵(confusion_matrix)FP、FN、TP、TN、ROC,精确率(Precision),召回率(Recall),准确率(Accuracy)详述与实现

一、FP、FN、TP、TN

你这蠢货,是不是又把酸葡萄和葡萄酸弄“混淆“”啦!!!

上面日常情况中的混淆就是:是否把某两件东西或者多件东西给弄混了,迷糊了。

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能.。混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量。

其中,这个矩阵的一行表示预测类中的实例(可以理解为模型预测输出,predict),另一列表示对该预测结果与标签(Ground Truth)进行判定模型的预测结果是否正确,正确为True,反之为False。

在机器学习中ground truth表示有监督学习的训练集的分类准确性,用于证明或者推翻某个假设。有监督的机器学习会对训练数据打标记,试想一下如果训练标记错误,那么将会对测试数据的预测产生影响,因此这里将那些正确打标记的数据成为ground truth。

此时,就引入FP、FN、TP、TN与精确率(Precision),召回率(Recall),准确率(Accuracy)。

以猫狗二分类为例,假定cat为正例-Positive,dog为负例-Negative;预测正确为True,反之为False。我们就可以得到下面这样一个表示FP、FN、TP、TN的表:

Python数据相关系数矩阵和热力图轻松实现教程

此时如下代码所示,其中scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口,可以用于绘制混淆矩阵

skearn.metrics.confusion_matrix(
 y_true, # array, Gound true (correct) target values
 y_pred, # array, Estimated targets as returned by a classifier
 labels=None, # array, List of labels to index the matrix.
 sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)

完整示例代码如下:

__author__ = "lingjun"
# welcome to attention:小白CV
 
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
sns.set()
 
f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)
y_true = ["dog", "dog", "dog", "cat", "cat", "cat", "cat"]
y_pred = ["cat", "cat", "dog", "cat", "cat", "cat", "cat"]
C2= confusion_matrix(y_true, y_pred, labels=["dog", "cat"])
print(C2)
print(C2.ravel())
sns.heatmap(C2,annot=True)
 
ax2.set_title('sns_heatmap_confusion_matrix')
ax2.set_xlabel('Pred')
ax2.set_ylabel('True')
f.savefig('sns_heatmap_confusion_matrix.jpg', bbox_inches='tight')

保存的图像如下所示:

Python数据相关系数矩阵和热力图轻松实现教程

这个时候我们还是不知道skearn.metrics.confusion_matrix做了些什么,这个时候print(C2),打印看下C2究竟里面包含着什么。最终的打印结果如下所示:

[[1 2]
 [0 4]]
[1 2 0 4]

解释下上面这几个数字的意思:

C2= confusion_matrix(y_true, y_pred, labels=["dog", "cat"])中的labels的顺序就分布是0、1,negative和positive

注:labels=[]可加可不加,不加情况下会自动识别,自己定义

cat为1-positive,其中真实值中cat有4个,4个被预测为cat,预测正确T,0个被预测为dog,预测错误F;

dog为0-negative,其中真实值中dog有3个,1个被预测为dog,预测正确T,2个被预测为cat,预测错误F。

所以:TN=1、 FP=2 、FN=0、TP=4。

TN=1:预测为negative狗中1个被预测正确了

FP=2 :预测为positive猫中2个被预测错误了

FN=0:预测为negative狗中0个被预测错误了

TP=4:预测为positive猫中4个被预测正确了

Python数据相关系数矩阵和热力图轻松实现教程

这时候再把上面猫狗预测结果拿来看看,6个被预测为cat,但是只有4个的true是cat,此时就和右侧的红圈对应上了。

y_pred = ["cat", "cat", "dog", "cat", "cat", "cat", "cat"]
y_true = ["dog", "dog", "dog", "cat", "cat", "cat", "cat"]

二、精确率(Precision),召回率(Recall),准确率(Accuracy)

有了上面的这些数值,就可以进行如下的计算工作了

准确率(Accuracy):这三个指标里最直观的就是准确率: 模型判断正确的数据(TP+TN)占总数据的比例

"Accuracy: "+str(round((tp+tn)/(tp+fp+fn+tn), 3))

召回率(Recall): 针对数据集中的所有正例label(TP+FN)而言,模型正确判断出的正例(TP)占数据集中所有正例的比例;FN表示被模型误认为是负例但实际是正例的数据;召回率也叫查全率,以物体检测为例,我们往往把图片中的物体作为正例,此时召回率高代表着模型可以找出图片中更多的物体!

"Recall: "+str(round((tp)/(tp+fn), 3))

精确率(Precision):针对模型判断出的所有正例(TP+FP)而言,其中真正例(TP)占的比例。精确率也叫查准率,还是以物体检测为例,精确率高表示模型检测出的物体中大部分确实是物体,只有少量不是物体的对象被当成物体。

"Precision: "+str(round((tp)/(tp+fp), 3))

还有:

("Sensitivity: "+str(round(tp/(tp+fn+0.01), 3)))
("Specificity: "+str(round(1-(fp/(fp+tn+0.01)), 3)))
("False positive rate: "+str(round(fp/(fp+tn+0.01), 3)))
("Positive predictive value: "+str(round(tp/(tp+fp+0.01), 3)))
("Negative predictive value: "+str(round(tn/(fn+tn+0.01), 3)))

三.绘制ROC曲线,及计算以上评价参数

如下为统计数据:

Python数据相关系数矩阵和热力图轻松实现教程

__author__ = "lingjun"
# E-mail: 1763469890@qq.com
 
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, auc
from matplotlib import pyplot as plt
import numpy as np
import torch
import csv
 
def confusion_matrix_roc(GT, PD, experiment, n_class):
 GT = GT.numpy()
 PD = PD.numpy()
 
 y_gt = np.argmax(GT, 1)
 y_gt = np.reshape(y_gt, [-1])
 y_pd = np.argmax(PD, 1)
 y_pd = np.reshape(y_pd, [-1])
 
 # ---- Confusion Matrix and Other Statistic Information ----
 if n_class > 2:
  c_matrix = confusion_matrix(y_gt, y_pd)
  # print("Confussion Matrix:\n", c_matrix)
  list_cfs_mtrx = c_matrix.tolist()
  # print("List", type(list_cfs_mtrx[0]))
 
  path_confusion = r"./records/" + experiment + "/confusion_matrix.txt"
  # np.savetxt(path_confusion, (c_matrix))
  np.savetxt(path_confusion, np.reshape(list_cfs_mtrx, -1), delimiter=',', fmt='%5s')
 
 if n_class == 2:
  list_cfs_mtrx = []
  tn, fp, fn, tp = confusion_matrix(y_gt, y_pd).ravel()
 
  list_cfs_mtrx.append("TN: " + str(tn))
  list_cfs_mtrx.append("FP: " + str(fp))
  list_cfs_mtrx.append("FN: " + str(fn))
  list_cfs_mtrx.append("TP: " + str(tp))
  list_cfs_mtrx.append(" ")
  list_cfs_mtrx.append("Accuracy: " + str(round((tp + tn) / (tp + fp + fn + tn), 3)))
  list_cfs_mtrx.append("Sensitivity: " + str(round(tp / (tp + fn + 0.01), 3)))
  list_cfs_mtrx.append("Specificity: " + str(round(1 - (fp / (fp + tn + 0.01)), 3)))
  list_cfs_mtrx.append("False positive rate: " + str(round(fp / (fp + tn + 0.01), 3)))
  list_cfs_mtrx.append("Positive predictive value: " + str(round(tp / (tp + fp + 0.01), 3)))
  list_cfs_mtrx.append("Negative predictive value: " + str(round(tn / (fn + tn + 0.01), 3)))
 
  path_confusion = r"./records/" + experiment + "/confusion_matrix.txt"
  np.savetxt(path_confusion, np.reshape(list_cfs_mtrx, -1), delimiter=',', fmt='%5s')
 
 # ---- ROC ----
 plt.figure(1)
 plt.figure(figsize=(6, 6))
 
 fpr, tpr, thresholds = roc_curve(GT[:, 1], PD[:, 1])
 roc_auc = auc(fpr, tpr)
 
 plt.plot(fpr, tpr, lw=1, label="ATB vs NotTB, area=%0.3f)" % (roc_auc))
 # plt.plot(thresholds, tpr, lw=1, label='Thr%d area=%0.2f)' % (1, roc_auc))
 # plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')
 
 plt.xlim([0.00, 1.0])
 plt.ylim([0.00, 1.0])
 plt.xlabel("False Positive Rate")
 plt.ylabel("True Positive Rate")
 plt.title("ROC")
 plt.legend(loc="lower right")
 plt.savefig(r"./records/" + experiment + "/ROC.png")
 print("ok")
 
def inference():
 GT = torch.FloatTensor()
 PD = torch.FloatTensor()
 file = r"Sensitive_rename_inform.csv"
 with open(file, 'r', encoding='UTF-8') as f:
  reader = csv.DictReader(f)
  for row in reader:
   # TODO
   max_patient_score = float(row['ai1'])
   doctor_gt = row['gt2']
 
   print(max_patient_score,doctor_gt)
 
   pd = [[max_patient_score, 1-max_patient_score]]
   output_pd = torch.FloatTensor(pd).to(device)
 
   if doctor_gt == "+":
    target = [[1.0, 0.0]]
   else:
    target = [[0.0, 1.0]]
   target = torch.FloatTensor(target) # 类型转换, 将list转化为tensor, torch.FloatTensor([1,2])
   Target = torch.autograd.Variable(target).long().to(device)
 
   GT = torch.cat((GT, Target.float().cpu()), 0) # 在行上进行堆叠
   PD = torch.cat((PD, output_pd.float().cpu()), 0)
 
 confusion_matrix_roc(GT, PD, "ROC", 2)
 
if __name__ == "__main__":
 inference()

若是表格里面有中文,则记得这里进行修改,否则报错

with open(file, 'r') as f:

以上这篇Python数据相关系数矩阵和热力图轻松实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyqt和pyside开发图形化界面
Jan 22 Python
详解Python中time()方法的使用的教程
May 22 Python
Python中字典(dict)合并的四种方法总结
Aug 10 Python
简述Python2与Python3的不同点
Jan 21 Python
python对列进行平移变换的方法(shift)
Jan 10 Python
Python两台电脑实现TCP通信的方法示例
May 06 Python
python 实现查找文件并输出满足某一条件的数据项方法
Jun 12 Python
windows10下安装TensorFlow Object Detection API的步骤
Jun 13 Python
Python解析命令行读取参数之argparse模块
Jul 26 Python
Django获取应用下的所有models的例子
Aug 30 Python
在pandas中遍历DataFrame行的实现方法
Oct 23 Python
Python selenium键盘鼠标事件实现过程详解
Jul 28 Python
matplotlib.pyplot.matshow 矩阵可视化实例
Jun 16 #Python
使用python matploblib库绘制准确率,损失率折线图
Jun 16 #Python
为什么称python为胶水语言
Jun 16 #Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 #Python
Python Socket TCP双端聊天功能实现过程详解
Jun 15 #Python
Python实现验证码识别
Jun 15 #Python
Python Tkinter图形工具使用方法及实例解析
Jun 15 #Python
You might like
叶罗丽:为什么大家对颜冰这对CP非常关心,却对金茉两人十分冷漠
2020/03/17 国漫
解决File size limit exceeded 错误的方法
2013/06/14 PHP
深入分析PHP引用(&)
2014/09/04 PHP
基于php(Thinkphp)+jquery 实现ajax多选反选不选删除数据功能
2017/02/24 PHP
CodeIgniter框架验证码类库文件与用法示例
2017/03/18 PHP
PHP多进程编程实例详解
2017/07/19 PHP
jQuery的Ajax时无响应数据的解决方法
2010/05/25 Javascript
如何使用jQUery获取选中radio对应的值(一句代码)
2013/06/03 Javascript
window.navigate 与 window.location.href 的使用区别介绍
2013/09/21 Javascript
js实现iframe跨页面调用函数的方法
2014/12/13 Javascript
在JS方法中返回多个值的方法汇总
2015/05/20 Javascript
Angularjs中的事件广播 —全面解析$broadcast,$emit,$on
2016/05/17 Javascript
深入浅析AngularJS中的一次性数据绑定 (bindonce)
2017/05/11 Javascript
jQuery开源组件BootstrapValidator使用详解
2017/06/29 jQuery
微信小程序之多文件下载的简单封装示例
2018/01/29 Javascript
详解VUE单页应用骨架屏方案
2019/01/17 Javascript
Node.js中package.json中库的版本号(~和^)
2019/04/02 Javascript
浅谈bootstrap layer.open中end的使用方法
2019/09/12 Javascript
vue使用screenfull插件实现全屏功能
2020/09/17 Javascript
使用AutoJs实现微信抢红包的代码
2020/12/31 Javascript
[42:50]NB vs VP 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
Python程序员鲜为人知但你应该知道的17个问题
2014/06/04 Python
简单介绍Python中利用生成器实现的并发编程
2015/05/04 Python
浅谈python迭代器
2017/11/08 Python
Flask模拟实现CSRF攻击的方法
2018/07/24 Python
Python如何实现动态数组
2019/11/02 Python
python 安装教程之Pycharm安装及配置字体主题,换行,自动更新
2020/03/13 Python
Django与pyecharts结合的实例代码
2020/05/13 Python
python代码中怎么换行
2020/06/17 Python
Pytest测试框架基本使用方法详解
2020/11/25 Python
结合CSS3的新特性来总结垂直居中的实现方法
2016/05/30 HTML / CSS
汤米巴哈马官方网站:Tommy Bahama
2017/05/13 全球购物
韩国乐天网上商城:Lotte iMall
2021/02/03 全球购物
幼师自我鉴定范文
2013/10/01 职场文书
顶碗少年教学反思
2014/02/21 职场文书
我与祖国共奋进演讲稿
2014/09/13 职场文书