pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中对list去重的多种方法
Sep 18 Python
python爬虫爬取网页表格数据
Mar 07 Python
Python Numpy 数组的初始化和基本操作
Mar 13 Python
python基础教程项目二之画幅好画
Apr 02 Python
详解python多线程、锁、event事件机制的简单使用
Apr 27 Python
[原创]Python入门教程5. 字典基本操作【定义、运算、常用函数】
Nov 01 Python
python画图系列之个性化显示x轴区段文字的实例
Dec 13 Python
使用Python画股票的K线图的方法步骤
Jun 28 Python
Python-jenkins模块获取jobs的执行状态操作
May 12 Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 Python
Python中logging日志记录到文件及自动分割的操作代码
Aug 05 Python
基于Python模拟浏览器发送http请求
Nov 06 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
ThinkPHP 模板引擎使用详解
2017/05/07 PHP
PHP折半(二分)查找算法实例分析
2018/05/12 PHP
Yii框架连表查询操作示例
2019/09/06 PHP
JavaScript 应用类库代码
2008/06/02 Javascript
分享14个很酷的jQuery导航菜单插件
2011/04/25 Javascript
JS实现向表格行添加新单元格的方法
2015/03/30 Javascript
angularjs学习笔记之完整的项目结构
2015/09/26 Javascript
JavaScript知识点总结(十一)之js中的Object类详解
2016/05/31 Javascript
关于网页中的无缝滚动的js代码
2016/06/09 Javascript
JavaScript结合Bootstrap仿微信后台多图文界面管理
2016/07/22 Javascript
H5手机端多文件上传预览插件
2017/04/21 Javascript
页面缩放兼容性处理方法(zoom,Firefox火狐浏览器)
2017/08/29 Javascript
详解Node.js利用node-git-server快速搭建git服务器
2017/09/27 Javascript
JavaScript单线程和任务队列原理解析
2020/02/04 Javascript
移动端JS实现拖拽两种方法解析
2020/10/12 Javascript
python函数返回多个值的示例方法
2013/12/04 Python
Python后台开发Django的教程详解(启动)
2019/04/08 Python
使用python爬取微博数据打造一颗“心”
2019/06/28 Python
对tensorflow中的strides参数使用详解
2020/01/04 Python
python设置环境变量的作用整理
2020/02/17 Python
如何利用Python给自己的头像加一个小国旗(小月饼)
2020/10/02 Python
The Kooples美国官方网站:为情侣提供的法国当代时尚品牌
2019/01/03 全球购物
eBay意大利购物网站:eBay.it
2019/09/04 全球购物
详解如何解决使用JSON.stringify时遇到的循环引用问题
2021/03/23 Javascript
音乐系毕业生自荐信
2013/10/27 职场文书
珍珠奶茶店创业计划书
2014/01/11 职场文书
给学校的建议书
2014/03/12 职场文书
小学开学典礼主持词
2014/03/19 职场文书
教师工作失职检讨书
2014/09/18 职场文书
2015年大学生村官工作总结
2015/04/21 职场文书
应急管理工作总结2015
2015/05/04 职场文书
硕士学位申请报告
2015/05/15 职场文书
单位接收证明格式
2015/06/18 职场文书
搞笑婚礼主持词开场白
2015/11/24 职场文书
Idea连接MySQL数据库出现中文乱码的问题
2021/04/14 MySQL
Nginx stream 配置代理(Nginx TCP/UDP 负载均衡)
2021/11/17 Servers