pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现多线程采集的2个代码例子
Jul 07 Python
Python中Random和Math模块学习笔记
May 18 Python
python图像常规操作
Nov 11 Python
Python+tkinter使用80行代码实现一个计算器实例
Jan 16 Python
Python实现将doc转化pdf格式文档的方法
Jan 19 Python
python文本数据相似度的度量
Mar 12 Python
python实现隐马尔科夫模型HMM
Mar 25 Python
Python闭包函数定义与用法分析
Jul 20 Python
Django发送邮件功能实例详解
Sep 02 Python
python实现拉普拉斯特征图降维示例
Nov 25 Python
使用Python操作ArangoDB的方法步骤
Feb 02 Python
彻底弄懂Python中的回调函数(callback)
Jun 25 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
Ha0k 0.3 PHP 网页木马修改版
2009/10/11 PHP
PHP5.2下preg_replace函数的问题
2015/05/08 PHP
php把数组值转换成键的方法
2015/07/13 PHP
thinkPHP框架对接支付宝即时到账接口回调操作示例
2016/11/14 PHP
读jQuery之十四 (触发事件核心方法)
2011/08/23 Javascript
js去除浏览器默认底图的方法
2015/06/08 Javascript
详细介绍jQuery.outerWidth() 函数具体用法
2015/07/20 Javascript
javascript学习指南之回调问题
2016/04/23 Javascript
利用Javascript实现BMI计算器
2016/08/16 Javascript
关于JavaScript数组你所不知道的3件事
2016/08/24 Javascript
EditPlus 正则表达式 实战(3)
2016/12/15 Javascript
jQuery ajax动态生成table功能示例
2017/06/14 jQuery
Vue slot用法(小结)
2018/10/22 Javascript
vue动态渲染svg、添加点击事件的实现
2020/03/13 Javascript
vue 在服务器端直接修改请求的接口地址
2020/12/19 Vue.js
[11:01]2014DOTA2西雅图邀请赛 冷冷带你探秘威斯汀
2014/07/08 DOTA
[01:59]翻天覆地,因你而变,7.20版本地图更新速览
2018/11/24 DOTA
Pyhton中防止SQL注入的方法
2015/02/05 Python
pip安装py_zipkin时提示的SSL问题对应
2018/12/29 Python
python创建ArcGIS shape文件的实现
2019/12/06 Python
Python 实现Serial 与STM32J进行串口通讯
2019/12/18 Python
Python SSL证书验证问题解决方案
2020/01/13 Python
python爬虫模块URL管理器模块用法解析
2020/02/03 Python
python 实现数据库中数据添加、查询与更新的示例代码
2020/12/07 Python
俄罗斯的精英皮具:Wittchen
2018/01/29 全球购物
Moda Italia荷兰:意大利男士服装
2019/08/31 全球购物
Miller Harris官网:英国小众香水品牌
2020/09/24 全球购物
革命先烈的英雄事迹材料
2014/02/15 职场文书
行政人事经理职位说明书
2014/03/05 职场文书
《特殊的葬礼》教学反思
2014/04/27 职场文书
奶茶店创业计划书
2014/08/14 职场文书
机关党总支领导班子整改方案
2014/09/20 职场文书
KTV门卫岗位职责
2014/10/09 职场文书
工程部经理岗位职责
2015/02/02 职场文书
寻找最美乡村教师观后感
2015/06/18 职场文书
2019年亲子运动会口号
2019/10/11 职场文书