pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python的Tornado框架异步编程入门实例
Apr 24 Python
Python实现感知器模型、两层神经网络
Dec 19 Python
Pycharm代码无法复制,无法选中删除,无法编辑的解决方法
Oct 22 Python
Python利用sqlacodegen自动生成ORM实体类示例
Jun 04 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
Jun 13 Python
python实现数据分析与建模
Jul 11 Python
python Django 创建应用过程图示详解
Jul 29 Python
python 画出使用分类器得到的决策边界
Aug 21 Python
基于Python批量生成指定尺寸缩略图代码实例
Nov 20 Python
Python中Flask-RESTful编写API接口(小白入门)
Dec 11 Python
python中的装饰器该如何使用
Jun 18 Python
python manim实现排序算法动画示例
Aug 14 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
服务器端解压缩zip的脚本
2006/12/22 PHP
最新用php获取谷歌PR值算法,附上php查询PR值代码示例
2011/12/25 PHP
分享一个PHP数据流应用的简单例子
2012/06/01 PHP
php定义数组和使用示例(php数组的定义方法)
2014/03/29 PHP
PHP解码unicode编码的中文字符代码分享
2014/08/13 PHP
dedecms中使用php语句指南
2014/11/13 PHP
Yii2 队列 shmilyzxt/yii2-queue 简单概述
2017/08/02 PHP
php实现的生成迷宫与迷宫寻址算法完整实例
2017/11/06 PHP
php+ajax 文件上传代码实例
2019/03/18 PHP
兼容多浏览器的iframe自适应高度(ie8 、谷歌浏览器4.0和 firefox3.5.3)
2009/11/04 Javascript
一个原生的用户等级的进度条
2010/07/03 Javascript
javascript中最常用的继承模式 组合继承
2010/08/12 Javascript
Jquery Validate 正则表达式实用验证代码大全
2013/08/23 Javascript
JS实现HTML标签转义及反转义
2020/04/14 Javascript
Bootstrap风格的zTree右键菜单
2017/02/17 Javascript
Vue组件之Tooltip的示例代码
2017/10/18 Javascript
Angular搜索 过滤 批量删除 添加 表单验证功能集锦(实例代码)
2017/10/25 Javascript
在vue2.0中引用element-ui组件库的方法
2018/06/21 Javascript
Vue2.0学习系列之项目上线的方法步骤(图文)
2018/09/25 Javascript
原生js实现Flappy Bird小游戏
2018/12/24 Javascript
vue实现输入框自动跳转功能
2020/05/20 Javascript
python魔法方法-属性访问控制详解
2016/07/25 Python
基于Django的ModelForm组件(详解)
2017/12/07 Python
和孩子一起学习python之变量命名规则
2018/05/27 Python
Python数据预处理之数据规范化(归一化)示例
2019/01/08 Python
react+django清除浏览器缓存的几种方法小结
2019/07/17 Python
python爬虫模拟浏览器访问-User-Agent过程解析
2019/12/28 Python
Python3+Appium安装及Appium模拟微信登录方法详解
2021/02/16 Python
CSS3控制HTML元素动画效果
2014/02/08 HTML / CSS
兰蔻加拿大官方网站:Lancome加拿大
2016/08/05 全球购物
毕业生个人求职信范文分享
2014/01/05 职场文书
信息员培训方案
2014/06/12 职场文书
碧霞祠导游词
2015/02/09 职场文书
2015年组织部工作总结
2015/04/03 职场文书
Django Paginator分页器的使用示例
2021/06/23 Python
win sever 2022如何占用操作主机角色
2022/06/25 Servers