pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python时间整形转标准格式的示例分享
Feb 14 Python
Python中装饰器兼容加括号和不加括号的写法详解
Jul 05 Python
Python DataFrame设置/更改列表字段/元素类型的方法
Jun 09 Python
Django添加KindEditor富文本编辑器的使用
Oct 24 Python
[机器视觉]使用python自动识别验证码详解
May 16 Python
用vue.js组件模拟v-model指令实例方法
Jul 05 Python
python实现列表中最大最小值输出的示例
Jul 09 Python
Python 中list ,set,dict的大规模查找效率对比详解
Oct 11 Python
Django后端发送小程序微信模板消息示例(服务通知)
Dec 17 Python
简单了解python filter、map、reduce的区别
Jan 14 Python
keras 指定程序在某块卡上训练实例
Jun 22 Python
python使用scapy模块实现ARP扫描的过程
Jan 21 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
基于数据库的在线人数,日访问量等统计
2006/10/09 PHP
php的一个登录的类 [推荐]
2007/03/16 PHP
LotusPhp笔记之:基于ObjectUtil组件的使用分析
2013/05/06 PHP
Function eregi is deprecated (解决方法)
2013/06/21 PHP
PHP实现查询手机归属地的方法详解
2017/04/28 PHP
php合并数组并保留键值的实现方法
2018/03/12 PHP
php实现映射操作实例详解
2019/10/02 PHP
PHP7 安装event扩展的实现方法
2019/10/08 PHP
Json对象替换字符串占位符实现代码
2010/11/17 Javascript
jQuery1.6 正式版发布并提供下载
2011/05/05 Javascript
Nodejs+express+html5 实现拖拽上传
2014/08/08 NodeJs
使用jQuery在对象中缓存选择器的简单方法
2015/06/30 Javascript
JS对HTML表格进行增删改操作
2016/08/22 Javascript
JavaScript中push(),join() 函数 实例详解
2016/09/06 Javascript
JS正则匹配中文的方法示例
2017/01/06 Javascript
微信小程序 限制1M的瘦身技巧与方法详解
2017/01/06 Javascript
js实现下拉框效果(select)
2017/03/28 Javascript
加载 vue 远程代码的组件实例详解
2017/11/20 Javascript
js实现二级菜单点击显示当前内容效果
2018/04/28 Javascript
JavaScript数组、json对象、eval()函数用法实例分析
2019/02/21 Javascript
Vue3新特性之在Composition API中使用CSS Modules
2020/07/13 Javascript
解决await在forEach中不起作用的问题
2021/02/25 Javascript
[47:03]Ti4第二日主赛事败者组 LGD vs iG 2
2014/07/21 DOTA
Python实现批量下载图片的方法
2015/07/08 Python
Python实现快速多线程ping的方法
2015/07/15 Python
Python PyQt5实现的简易计算器功能示例
2017/08/23 Python
利用 Python ElementTree 生成 xml的实例
2020/03/06 Python
Selenium启动Chrome时配置选项详解
2020/03/18 Python
CentOS 7如何实现定时执行python脚本
2020/06/24 Python
联想香港官方网站及网店:Lenovo香港
2018/04/13 全球购物
L’urv官网:精品女性运动服品牌
2019/07/07 全球购物
我的中国心演讲稿
2014/09/04 职场文书
幼儿园中班教师个人总结
2015/02/05 职场文书
个人总结格式范文
2015/03/09 职场文书
2015中秋祝酒词
2015/08/12 职场文书
一篇文章带你复习java知识点
2021/06/28 Java/Android