pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python搭建虚拟环境的步骤详解
Sep 27 Python
python字符串,数值计算
Oct 05 Python
Python处理CSV与List的转换方法
Apr 19 Python
pygame游戏之旅 添加游戏界面按键图形
Nov 20 Python
Python一行代码实现快速排序的方法
Apr 30 Python
Python redis操作实例分析【连接、管道、发布和订阅等】
May 16 Python
在Pycharm中调试Django项目程序的操作方法
Jul 17 Python
Python3打包exe代码2种方法实例解析
Feb 17 Python
python GUI库图形界面开发之PyQt5计数器控件QSpinBox详细使用方法与实例
Feb 28 Python
Python爬虫基于lxml解决数据编码乱码问题
Jul 31 Python
Python一行代码实现自动发邮件功能
May 30 Python
在pycharm中无法import所安装的库解决方案
May 31 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
YII路径的用法总结
2014/07/09 PHP
PHP实现HTML页面静态化的方法
2015/11/04 PHP
Yii框架组件和事件行为管理详解
2016/05/20 PHP
如何简单地用YUI做JavaScript动画
2007/03/10 Javascript
用javascript实现自定义标签
2007/05/08 Javascript
csdn 博客中实现运行代码功能实现
2009/08/29 Javascript
利用js获取服务器时间的两个简单方法
2010/01/08 Javascript
统计出现最多的字符次数的js代码
2010/12/03 Javascript
JS实现图片产生波纹一样flash效果的方法
2015/02/27 Javascript
JS右下角广告窗口代码(可收缩、展开及关闭)
2015/09/04 Javascript
基于jQuery实现弹出可关闭遮罩提示框实例代码
2016/07/18 Javascript
vue开发心得和技巧分享
2016/10/27 Javascript
AngularJS过滤器filter用法实例分析
2016/11/04 Javascript
快速搭建React的环境步骤详解
2017/11/06 Javascript
详解React Native 采用Fetch方式发送跨域POST请求
2017/11/15 Javascript
JavaScript Array对象基本方法详解
2019/09/03 Javascript
JavaScript数组排序的六种常见算法总结
2020/08/18 Javascript
解决ant Design Search无法输入内容的问题
2020/10/29 Javascript
vue+echarts+datav大屏数据展示及实现中国地图省市县下钻功能
2020/11/16 Javascript
[01:07:17]EG vs Optic Supermajor 败者组 BO3 第一场 6.6
2018/06/07 DOTA
详解python发送各类邮件的主要方法
2016/12/22 Python
Python Pywavelet 小波阈值实例
2019/01/09 Python
python 一个figure上显示多个图像的实例
2019/07/08 Python
python实现贪吃蛇双人大战
2020/04/18 Python
如何从csv文件构建Tensorflow的数据集
2020/09/21 Python
Python logging自定义字段输出及打印颜色
2020/11/30 Python
thinkphp5 路由分发原理
2021/03/18 PHP
HTML5移动端开发中的Viewport标签及相关CSS用法解析
2016/04/15 HTML / CSS
英语专业学生个人求职信范文
2014/01/06 职场文书
计算机专业职业生涯规划范文
2014/01/19 职场文书
团日活动总结
2014/04/28 职场文书
营销部内勤岗位职责
2014/04/30 职场文书
大专毕业生自我鉴定范文(2篇)
2014/09/27 职场文书
2015年七一建党节活动方案
2015/05/05 职场文书
大学学习委员竞选稿
2015/11/20 职场文书
电脑开机弹出documents文件夹怎么回事?弹出documents文件夹解决方法
2022/04/08 数码科技