pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python 学习笔记
Dec 27 Python
Python删除指定目录下过期文件的2个脚本分享
Apr 10 Python
python连接MySQL数据库实例分析
May 12 Python
Python使用sftp实现上传和下载功能(实例代码)
Mar 14 Python
Python中几种导入模块的方式总结
Apr 27 Python
Django使用详解:ORM 的反向查找(related_name)
May 30 Python
详解django使用include无法跳转的解决方法
Mar 19 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
Jun 01 Python
python 基于wx实现音乐播放
Nov 24 Python
Python 语言实现六大查找算法
Jun 30 Python
详解python的异常捕获
Mar 03 Python
解决IDEA翻译插件Translation报错更新TTK失败不能使用
Apr 24 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
PHP实现深度优先搜索算法(DFS,Depth First Search)详解
2017/09/16 PHP
JavaScript中的事件处理
2008/01/16 Javascript
两个select之间option的互相添加操作(jquery实现)
2009/11/12 Javascript
JavaScript全局函数使用简单说明
2011/03/11 Javascript
Javascript 八进制转义字符(8进制)
2011/04/08 Javascript
jquery内置验证(validate)使用方法示例(表单验证)
2013/12/04 Javascript
让JavaScript的Alert弹出框失效的方法禁止弹出警告框
2014/09/03 Javascript
Javascript基础教程之argument 详解
2015/01/18 Javascript
PHP 数组current和next用法分享
2015/03/05 Javascript
基于javascript数组实现图片轮播
2016/05/02 Javascript
laypage分页控件使用实例详解
2016/05/19 Javascript
js实现可控制左右方向的无缝滚动效果
2016/05/29 Javascript
JS实现重新加载当前页面
2016/11/29 Javascript
基于JavaScript实现窗口拖动效果
2017/01/18 Javascript
javascript事件的绑定基础实例讲解(34)
2017/02/14 Javascript
JavaScript中的遍历详解(多种遍历)
2017/04/07 Javascript
JavaScript实现滑动导航栏效果
2017/08/30 Javascript
js实现加载页面就自动触发超链接的示例
2017/08/31 Javascript
vue中封装axios并实现api接口的统一管理
2020/12/25 Vue.js
Python三级目录展示的实现方法
2016/09/28 Python
把csv文件转化为数组及数组的切片方法
2018/07/04 Python
python GUI库图形界面开发之PyQt5信号与槽的高级使用技巧装饰器信号与槽详细使用方法与实例
2020/03/06 Python
python 两种方法删除空文件夹
2020/09/29 Python
Python3爬虫ChromeDriver的安装实例
2021/02/06 Python
使用HTML5 IndexDB存储图像和文件的示例
2018/11/05 HTML / CSS
门卫班长岗位职责
2013/12/15 职场文书
写给女生的道歉信
2014/01/08 职场文书
工作态度不端正检讨书
2014/10/04 职场文书
防灾减灾宣传标语
2014/10/07 职场文书
房产分割协议书范文
2014/11/21 职场文书
人事聘任通知
2015/04/21 职场文书
观后感的写法
2015/06/19 职场文书
2016年小学生寒假总结
2015/10/10 职场文书
2016大学生毕业实习心得体会
2016/01/23 职场文书
python turtle绘制多边形和跳跃和改变速度特效
2022/03/16 Python
win10+RTX3050ti+TensorFlow+cudn+cudnn配置深度学习环境的方法
2022/06/25 Servers