pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python读写ini文件示例(python读写文件)
Mar 25 Python
玩转python爬虫之cookie使用方法
Feb 17 Python
python正则分析nginx的访问日志
Jan 17 Python
使用Python对Excel进行读写操作
Mar 30 Python
windows下Python实现将pdf文件转化为png格式图片的方法
Jul 21 Python
​如何愉快地迁移到 Python 3
Apr 28 Python
Django网络框架之HelloDjango项目创建教程
Jun 06 Python
python文件选择对话框的操作方法
Jun 27 Python
python实现复制大量文件功能
Aug 31 Python
基于virtualenv创建python虚拟环境过程图解
Mar 30 Python
解决Keras中Embedding层masking与Concatenate层不可调和的问题
Jun 18 Python
Python如何识别银行卡卡号?
Jun 10 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
用文本文件实现的动态实时发布新闻的程序
2006/10/09 PHP
php解析base64数据生成图片的方法
2016/12/06 PHP
php实现银联商务公众号+服务窗支付的示例代码
2019/10/12 PHP
php ActiveMQ的安装与使用方法图文教程
2020/02/23 PHP
ThinkPHP3.1.2 使用cli命令行模式运行的方法
2020/04/14 PHP
tp5.1 框架数据库高级查询技巧实例总结
2020/05/25 PHP
SyntaxHighlighter代码加色使用方法
2008/09/07 Javascript
关于可运行代码无法正常执行的使用说明
2010/05/13 Javascript
Jquery公告滚动+AJAX后台得到数据
2011/04/14 Javascript
js操作checkbox遇到的问题解决
2013/06/29 Javascript
JQuery Highcharts 动态生成图表的方法
2013/11/15 Javascript
jQuery 移动端artEditor富文本编辑器
2016/01/11 Javascript
高效利用Angular中内置服务$http、$location等
2016/03/22 Javascript
vue绑定的点击事件阻止冒泡的实例
2018/02/08 Javascript
vue2中使用sass并配置全局的sass样式变量的方法
2018/09/04 Javascript
Vue+webpack项目配置便于维护的目录结构教程详解
2018/10/14 Javascript
微信小程序自定义头部导航栏(组件化)
2019/11/15 Javascript
python实现将pvr格式转换成pvr.ccz的方法
2015/04/28 Python
python实现清屏的方法
2015/04/30 Python
深入解析Python中的lambda表达式的用法
2015/08/28 Python
动态规划之矩阵连乘问题Python实现方法
2017/11/27 Python
selenium在执行phantomjs的API并获取执行结果的方法
2018/12/17 Python
如何使用django的MTV开发模式返回一个网页
2019/07/22 Python
python针对mysql数据库的连接、查询、更新、删除操作示例
2019/09/11 Python
python cv2读取rtsp实时码流按时生成连续视频文件方式
2019/12/25 Python
python图片指定区域替换img.paste函数的使用
2020/04/09 Python
python 使用while循环输出*组成的菱形实例
2020/04/12 Python
非常漂亮的CSS3百叶窗焦点图动画
2016/02/24 HTML / CSS
辅导员评语
2014/05/04 职场文书
劳动竞赛活动总结
2014/05/05 职场文书
2015新学期家长寄语
2015/02/26 职场文书
2015年暑期社会实践活动总结
2015/03/27 职场文书
保安辞职申请书应该怎么写?
2019/07/15 职场文书
详解运行Python的神器Jupyter Notebook
2021/06/03 Python
梳理总结Python开发中需要摒弃的18个坏习惯
2022/01/22 Python
基于Python实现西西成语接龙小助手
2022/08/05 Golang