pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python Django模板的使用方法(图文)
Nov 04 Python
linux系统使用python监控apache服务器进程脚本分享
Jan 15 Python
如何在VSCode上轻松舒适的配置Python的方法步骤
Oct 28 Python
Python使用微信接入图灵机器人过程解析
Nov 04 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 Python
opencv 查找连通区域 最大面积实例
Jun 04 Python
python 8种必备的gui库
Aug 27 Python
使用anaconda安装pytorch的实现步骤
Sep 03 Python
简单了解Python字典copy与赋值的区别
Sep 16 Python
python闭包与引用以及需要注意的陷阱
Sep 18 Python
pytorch 中forward 的用法与解释说明
Feb 26 Python
python可视化分析绘制带趋势线的散点图和边缘直方图
Jun 25 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
纯真IP数据库的应用 IP地址转化成十进制
2009/06/14 PHP
php中使用preg_match_all匹配文章中的图片
2013/02/06 PHP
jQuery-ui中自动完成实现方法
2010/06/10 Javascript
Javascript 面试题随笔
2011/03/31 Javascript
jquery获取table中的某行全部td的内容方法
2013/03/08 Javascript
文本有关的样式和jQuery求对象的高宽问题分别说明
2013/08/30 Javascript
推荐25个超炫的jQuery网格插件
2014/11/28 Javascript
JS实现至少包含字母、大小写数字、字符的密码等级的两种方法
2015/02/03 Javascript
JavaScript实现表格点击排序的方法
2015/05/11 Javascript
js倒计时抢购实例
2015/12/20 Javascript
jQuery实现下拉加载功能实例代码
2016/04/01 Javascript
网页前端登录js按Enter回车键实现登陆的两种方法
2016/05/10 Javascript
js实现截图保存图片功能的代码示例
2017/02/16 Javascript
jQuery图片瀑布流的简单实现代码
2017/03/15 Javascript
jQuery中hover方法搭配css的hover选择器,实现选中元素突出显示方法
2017/05/08 jQuery
JavaScript canvas实现雨滴特效
2021/01/10 Javascript
python 根据正则表达式提取指定的内容实例详解
2016/12/04 Python
Python递归函数定义与用法示例
2017/06/02 Python
Python绘制3d螺旋曲线图实例代码
2017/12/20 Python
浅析python中numpy包中的argsort函数的使用
2018/08/30 Python
python将一组数分成每3个一组的实例
2018/11/14 Python
python numpy元素的区间查找方法
2018/11/14 Python
在Python中合并字典模块ChainMap的隐藏坑【推荐】
2019/06/27 Python
Pytorch实现LSTM和GRU示例
2020/01/14 Python
浅析Python 条件控制语句
2020/07/15 Python
美国在线纱线商店:Darn Good Yarn
2019/03/20 全球购物
专业实习自我鉴定
2013/10/29 职场文书
业务员简历自我评价
2014/03/06 职场文书
法定代表人身份证明书
2014/09/10 职场文书
大学四年个人总结
2015/03/03 职场文书
签订劳动合同通知书
2015/04/16 职场文书
2016感恩母亲节校园广播稿
2015/12/17 职场文书
三年级作文之小小梦想
2019/12/06 职场文书
php字符串倒叙
2021/04/01 PHP
浅谈python中的多态
2021/06/15 Python
Java 在生活中的 10 大应用
2021/11/02 Java/Android