pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转码问题的解决方法
Oct 07 Python
用Python实现一个简单的能够发送带附件的邮件程序的教程
Apr 08 Python
python中set()函数简介及实例解析
Jan 09 Python
Python二叉树定义与遍历方法实例分析
May 25 Python
python简易远程控制单线程版
Jun 20 Python
python实现自动登录
Sep 17 Python
PyGame贪吃蛇的实现代码示例
Nov 21 Python
python字符串分割及字符串的一些常规方法
Jul 24 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
基于python 凸包问题的解决
Apr 16 Python
Tensorflow实现将标签变为one-hot形式
May 22 Python
详解python 内存优化
Aug 17 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
详解PHP函数 strip_tags 处理字符串缺陷bug
2017/06/11 PHP
表单提交验证类
2006/07/14 Javascript
实现JavaScript中继承的三种方式
2009/10/16 Javascript
javascript 事件绑定问题
2011/01/01 Javascript
DB.ASP 用Javascript写ASP很灵活很好用很easy
2011/07/31 Javascript
jQuery构造函数init参数分析
2015/05/13 Javascript
JavaScript实现的数字与字符串转换功能示例
2017/08/23 Javascript
JavaScript程序设计高级算法之动态规划实例分析
2017/11/24 Javascript
jQuery图片加载失败替换默认图片方法汇总
2017/11/29 jQuery
浅谈webpack组织模块的原理
2018/03/10 Javascript
KOA+egg.js集成kafka消息队列的示例
2018/11/09 Javascript
理理Vue细节(推荐)
2019/04/16 Javascript
js基于canvas实现时钟组件
2021/02/07 Javascript
Python中asyncore的用法实例
2014/09/29 Python
用C++封装MySQL的API的教程
2015/05/06 Python
python:socket传输大文件示例
2017/01/18 Python
全面分析Python的优点和缺点
2018/02/07 Python
spark: RDD与DataFrame之间的相互转换方法
2018/06/07 Python
基于DATAFRAME中元素的读取与修改方法
2018/06/08 Python
wxPython实现分隔窗口
2019/11/19 Python
使用python代码进行身份证号校验的实现示例
2019/11/21 Python
python 解压、复制、删除 文件的实例代码
2020/02/26 Python
python中把元组转换为namedtuple方法
2020/12/09 Python
使用pandas读取表格数据并进行单行数据拼接的详细教程
2021/03/03 Python
手工制作的意大利礼服鞋:Ace Marks
2018/12/15 全球购物
Zavvi西班牙:电子游戏、极客服装、Blu-ray、Funko Pop等
2019/05/03 全球购物
业务主管岗位职责范本
2013/12/25 职场文书
艾滋病宣传标语
2014/06/25 职场文书
公共场所标语
2014/06/30 职场文书
商场租赁意向书
2014/07/30 职场文书
2014年党员加强作风建设思想汇报
2014/09/15 职场文书
销售经理工作检讨书
2015/02/19 职场文书
2015年学校综合治理工作总结
2015/07/20 职场文书
党员学习中国梦心得体会
2016/01/05 职场文书
Golang全局变量加锁的问题解决
2021/05/08 Golang
Django Paginator分页器的使用示例
2021/06/23 Python