pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
举例讲解Python程序与系统shell交互的方式
Apr 09 Python
对python中return和print的一些理解
Aug 18 Python
python监控键盘输入实例代码
Feb 09 Python
Python图像处理之识别图像中的文字(实例讲解)
May 10 Python
Python 经典面试题 21 道【不可错过】
Sep 21 Python
Python3内置模块之json编解码方法小结【推荐】
Dec 09 Python
django 类视图的使用方法详解
Jul 24 Python
python实现的按要求生成手机号功能示例
Oct 08 Python
python中return的返回和执行实例
Dec 24 Python
Python列表切片常用操作实例解析
Mar 10 Python
如何基于python实现不邻接植花
May 01 Python
python使用openpyxl操作excel的方法步骤
May 28 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
用PHP连mysql和oracle数据库性能比较
2006/10/09 PHP
PHP自动更新新闻DIY
2006/10/09 PHP
php数组函数序列之in_array() - 查找数组中是否存在指定值
2011/11/07 PHP
php递归创建目录的方法
2015/02/02 PHP
JS多物体 任意值 链式 缓冲运动
2012/08/10 Javascript
如何获取select下拉框的值(option没有及有value属性)
2013/11/08 Javascript
javascript创建cookie、读取cookie
2016/03/31 Javascript
javascript回调函数的概念理解与用法分析
2017/05/27 Javascript
浅谈关于angularJs中使用$.ajax的注意点
2017/08/12 Javascript
详解基于Angular4+ server render(服务端渲染)开发教程
2017/08/28 Javascript
深入理解Node.js中通用基础设计模式
2017/09/19 Javascript
2种简单的js倒计时方式
2017/10/20 Javascript
基于JavaScript实现报警器提示音效果
2017/10/27 Javascript
vue实现验证码按钮倒计时功能
2018/04/10 Javascript
React props和state属性的具体使用方法
2018/04/12 Javascript
element-ui 上传图片后清空图片显示的实例
2018/09/04 Javascript
微信小程序实现底部导航
2018/11/05 Javascript
js如何验证密码强度
2020/03/18 Javascript
es6函数中的作用域实例分析
2020/04/18 Javascript
[01:06:12]VP vs NIP 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
在python的WEB框架Flask中使用多个配置文件的解决方法
2014/04/18 Python
以windows service方式运行Python程序的方法
2015/06/03 Python
django 实现电子支付功能的示例代码
2018/07/25 Python
浅谈pytorch grad_fn以及权重梯度不更新的问题
2019/08/20 Python
django框架单表操作之增删改实例分析
2019/12/16 Python
python 获取字典键值对的实现
2020/11/12 Python
python pillow库的基础使用教程
2021/01/13 Python
施华洛世奇日本官网:SWAROVSKI日本
2018/05/04 全球购物
甜点店创业计划书
2014/01/27 职场文书
公证委托书模板
2014/04/03 职场文书
个人股份合作协议书
2014/10/24 职场文书
工会工作个人总结
2015/03/03 职场文书
行政人事主管岗位职责
2015/04/11 职场文书
党员公开承诺书(2016最新版)
2016/03/24 职场文书
MySQL数据迁移相关总结
2021/04/29 MySQL
详解php中流行的rpc框架
2021/05/29 PHP