pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python函数返回多个值的示例方法
Dec 04 Python
使用Python的Tornado框架实现一个一对一聊天的程序
Apr 25 Python
基于Python实现一个简单的银行转账操作
Mar 06 Python
Python数据结构与算法之图的广度优先与深度优先搜索算法示例
Dec 14 Python
浅谈Matplotlib简介和pyplot的简单使用——文本标注和箭头
Jan 09 Python
Python3中内置类型bytes和str用法及byte和string之间各种编码转换 问题
Sep 27 Python
解决pandas.DataFrame.fillna 填充Nan失败的问题
Nov 06 Python
python 多进程共享全局变量之Manager()详解
Aug 15 Python
opencv设置采集视频分辨率方式
Dec 10 Python
Python3变量与基本数据类型用法实例分析
Feb 14 Python
解决Pycharm中恢复被exclude的项目问题(pycharm source root)
Feb 14 Python
端午节将至,用Python爬取粽子数据并可视化,看看网友喜欢哪种粽子吧!
Jun 11 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
php去掉URL网址中带有PHPSESSID的配置方法
2014/07/08 PHP
PHP开启opcache提升代码性能
2015/04/26 PHP
php resizeimage 部分jpg文件 生成缩略图失败的原因分析及解决办法
2016/03/23 PHP
PHP中set_include_path()函数相关用法分析
2016/07/18 PHP
TNC vs RR BO3 第一场 2.14
2021/03/10 DOTA
jQuery学习笔记(3)--用jquery(插件)实现多选项卡功能
2013/04/08 Javascript
Javascript基础教程之JavaScript语法
2015/01/18 Javascript
微信小程序 两种为对象属性赋值的方式详解
2017/02/23 Javascript
基于javascript中的typeof和类型判断(详解)
2017/10/27 Javascript
JS逻辑运算符短路操作实例分析
2018/07/09 Javascript
小程序实现多选框功能
2018/10/30 Javascript
javascript使用substring实现的展开与收缩文字功能示例
2019/06/17 Javascript
详解node.js 事件循环
2020/07/22 Javascript
详解JavaScript中的数据类型,以及检测数据类型的方法
2020/09/17 Javascript
JavaScript/TypeScript 实现并发请求控制的示例代码
2021/01/18 Javascript
python的描述符(descriptor)、装饰器(property)造成的一个无限递归问题分享
2014/07/09 Python
Python解析xml中dom元素的方法
2015/03/12 Python
Python中用memcached来减少数据库查询次数的教程
2015/04/07 Python
使用Python脚本将文字转换为图片的实例分享
2015/08/29 Python
Python中使用bidict模块双向字典结构的奇技淫巧
2016/07/12 Python
python利用paramiko连接远程服务器执行命令的方法
2017/10/16 Python
python三大神器之fabric使用教程
2019/06/10 Python
Python中一些深不见底的“坑”
2019/06/12 Python
基于python爬取有道翻译过程图解
2020/03/31 Python
用python写爬虫简单吗
2020/07/28 Python
Python爬虫获取豆瓣电影并写入excel
2020/07/31 Python
HTML5 Canvas实现图片缩放、翻转、颜色渐变的代码示例
2016/02/28 HTML / CSS
html5的input的required使用中遇到的问题及解决方法
2018/04/24 HTML / CSS
美国知名的时尚购物网站:Anthropologie
2016/12/22 全球购物
中专自荐信
2013/10/13 职场文书
数据管理员的自我评价分享
2013/11/15 职场文书
信息管理员岗位职责
2013/12/01 职场文书
“三支一扶”支教教师思想汇报
2014/09/13 职场文书
2014年行政人事工作总结
2014/12/09 职场文书
2015年度优秀员工自荐书
2015/03/06 职场文书
退休劳动合同怎么写?
2019/10/25 职场文书