pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python MD5加密实例详解
Aug 02 Python
Django 前后台的数据传递的方法
Aug 08 Python
对numpy中的where方法嵌套使用详解
Oct 31 Python
PyCharm设置每行最大长度限制的方法
Jan 16 Python
Django对models里的objects的使用详解
Aug 17 Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 Python
python实现广度优先搜索过程解析
Oct 19 Python
Python接口测试get请求过程详解
Feb 28 Python
python with语句的原理与用法详解
Mar 30 Python
Python Selenium库的基本使用教程
Jan 04 Python
总结Python使用过程中的bug
Jun 18 Python
常用的Python代码调试工具总结
Jun 23 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
编写漂亮的代码 - 将后台程序与前端程序分开
2008/04/23 PHP
PHP CURL模拟GET及POST函数代码
2010/04/25 PHP
按上下级层次关系输出内容的PHP代码
2010/07/17 PHP
PHP return语句的另一个作用
2014/07/30 PHP
smarty模板引擎之分配数据类型
2015/03/30 PHP
PHP版微信第三方实现一键登录及获取用户信息的方法
2016/10/14 PHP
php-msf源码详解
2017/12/25 PHP
jquery弹出层类代码分享
2013/12/27 Javascript
第五章之BootStrap 栅格系统
2016/04/25 Javascript
JS实现显示带倒影的图片横排居中放大展示特效实例【测试可用】
2016/08/23 Javascript
VueJS全面解析
2016/11/10 Javascript
flag和jq on 的绑定多个对象和方法(必看)
2017/02/27 Javascript
JS+HTML5 FileReader对象用法示例
2017/04/07 Javascript
Thinkphp5微信小程序获取用户信息接口的实例详解
2017/09/26 Javascript
微信小程序之swiper轮播图中的图片自适应高度的方法
2018/04/23 Javascript
微信小程序swiper实现文字纵向轮播提示效果
2020/01/21 Javascript
利用PHP实现递归删除链表元素的方法示例
2020/10/23 Javascript
Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例
2017/11/23 Python
Python实现提取XML内容并保存到Excel中的方法
2018/09/01 Python
Python实现的读取文件内容并写入其他文件操作示例
2019/04/09 Python
Python字符串的一些操作方法总结
2019/06/10 Python
感知器基础原理及python实现过程详解
2019/09/30 Python
python给视频添加背景音乐并改变音量的具体方法
2020/07/19 Python
基于PyTorch中view的用法说明
2021/03/03 Python
thinkphp5 路由分发原理
2021/03/18 PHP
澳大利亚便宜隐形眼镜购买网站:QUICKLENS Australia
2018/10/06 全球购物
N.Peal官网:来自伦敦的高档羊绒品牌
2018/10/29 全球购物
Ooni英国官网:披萨烤箱
2020/05/31 全球购物
幼教毕业生自我鉴定
2014/01/12 职场文书
2014年大学生自我评价
2014/01/19 职场文书
2014升学宴答谢词
2014/01/26 职场文书
颐和园的导游词
2015/01/30 职场文书
紧急通知
2015/04/17 职场文书
医院病假条范文
2015/08/17 职场文书
JavaScript高级程序设计之基本引用类型
2021/11/17 Javascript
Golang解析JSON对象
2022/04/30 Golang