pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
python的dataframe和matrix的互换方法
Apr 11 Python
转换科学计数法的数值字符串为decimal类型的方法
Jul 16 Python
Python学习笔记之读取文件、OS模块、异常处理、with as语法示例
Jun 04 Python
通过pycharm使用git的步骤(图文详解)
Jun 13 Python
python打印9宫格、25宫格等奇数格 满足横竖斜相加和相等
Jul 19 Python
Python常用模块os.path之文件及路径操作方法
Dec 03 Python
如何将 awk 脚本移植到 Python
Dec 09 Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 Python
Python钉钉报警及Zabbix集成钉钉报警的示例代码
Aug 17 Python
python speech模块的使用方法
Sep 09 Python
Python如何批量生成和调用变量
Nov 21 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
php下使用SMTP发邮件的代码
2008/01/10 PHP
php二维数组排序详解
2013/11/06 PHP
php中PDO方式实现数据库的增删改查
2015/05/17 PHP
PHP使用trim函数去除字符串左右空格及特殊字符实例
2016/01/07 PHP
php注册系统和使用Xajax即时验证用户名是否被占用
2017/08/31 PHP
PHP守护进程化在C和PHP环境下的实现
2017/11/21 PHP
PHP封装的page分页类定义与用法完整示例
2018/12/24 PHP
php7新特性的理解和比较总结
2019/04/14 PHP
Javascript - HTML的request类
2007/01/09 Javascript
用js实现随机返回数组的一个元素
2007/08/13 Javascript
学习javascript的闭包,原型,和匿名函数之旅
2015/10/18 Javascript
JavaScript学习小结之使用canvas画“哆啦A梦”时钟
2016/07/24 Javascript
JavaScript实现网页头部进度条刷新
2017/04/16 Javascript
Vue实现textarea固定输入行数与添加下划线样式的思路详解
2018/06/28 Javascript
JavaScript引用类型Array实例分析
2018/07/24 Javascript
原生JS实现$.param() 函数的方法
2018/08/10 Javascript
vue2.0 路由模式mode="history"的作用
2018/10/18 Javascript
vue+render+jsx实现可编辑动态多级表头table的实例代码
2020/04/01 Javascript
Vue-router中hash模式与history模式的区别详解
2020/12/15 Vue.js
Python中变量交换的例子
2014/08/25 Python
Python下调用Linux的Shell命令的方法
2018/06/12 Python
pytorch进行上采样的种类实例
2020/02/18 Python
Python 实现打印单词的菱形字符图案
2020/04/12 Python
python_matplotlib改变横坐标和纵坐标上的刻度(ticks)方式
2020/05/16 Python
python实现密码验证合格程序的思路详解
2020/06/01 Python
Python使用xpath实现图片爬取
2020/09/16 Python
Python浮点型(float)运算结果不正确的解决方案
2020/09/22 Python
中外合拍动画首获奥斯卡提名,“上海出品”《飞奔去月球》能否拿下最终大奖?
2021/03/16 国漫
汤米巴哈马官方网站:Tommy Bahama
2017/05/13 全球购物
餐饮业会计岗位职责
2013/12/19 职场文书
船舶专业个人求职信范文
2014/01/02 职场文书
企业宣传工作方案
2014/06/02 职场文书
小学亲子活动总结
2014/07/01 职场文书
婚礼庆典答谢词
2015/01/20 职场文书
男方婚前保证书
2015/02/28 职场文书
MySQL索引失效十种场景与优化方案
2023/05/08 MySQL