pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python optparse模块使用实例
Apr 09 Python
详细解读Python中解析XML数据的方法
Oct 15 Python
基于Linux系统中python matplotlib画图的中文显示问题的解决方法
Jun 15 Python
Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算
Dec 28 Python
强悍的Python读取大文件的解决方案
Feb 16 Python
Ubuntu18.04中Python2.7与Python3.6环境切换
Jun 14 Python
django的ORM操作 删除和编辑实现详解
Jul 24 Python
Django框架视图函数设计示例
Jul 29 Python
python  logging日志打印过程解析
Oct 22 Python
python模块hashlib(加密服务)知识点讲解
Nov 25 Python
Python 实现平台类游戏添加跳跃功能
Mar 27 Python
python搜索算法原理及实例讲解
Nov 18 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
别人整理的服务器变量:$_SERVER
2006/10/20 PHP
php cli 小技巧
2013/06/03 PHP
PHP+Mysql+Ajax+JS实现省市区三级联动
2014/05/23 PHP
PHP中的魔术方法总结和使用实例
2015/05/11 PHP
分享一个漂亮的php验证码类
2016/09/29 PHP
mac pecl 安装php7.1扩展教程
2019/10/17 PHP
动态加载iframe
2006/06/16 Javascript
JavaScript 浏览器验证代码(来自discuz)
2010/07/17 Javascript
JSChart轻量级图形报表工具(内置函数中文参考)
2010/10/11 Javascript
js中typeof的用法汇总
2013/12/12 Javascript
jQuery中ajax的load()方法用法实例
2014/12/26 Javascript
input输入框鼠标焦点提示信息
2015/03/17 Javascript
跟我学习javascript的定时器
2015/11/19 Javascript
学习JavaScript设计模式之策略模式
2016/01/12 Javascript
bootstrap3 兼容IE8浏览器!
2016/05/02 Javascript
Javascript中将变量转换为字符串的三种方法
2017/09/19 Javascript
mac上配置Android环境变量的方法
2018/07/08 Javascript
JS实现获取当前所在周的周六、周日示例分析
2019/05/11 Javascript
操作按钮悬浮固定在微信小程序底部的实现代码
2019/08/02 Javascript
微信小程序保持session会话的方法
2020/03/20 Javascript
[03:55]TI9战队采访——TNC Predator
2019/08/22 DOTA
python封装对象实现时间效果
2020/04/23 Python
在java中如何定义一个抽象属性示例详解
2017/08/18 Python
django加载本地html的方法
2018/05/27 Python
python根据list重命名文件夹里的所有文件实例
2018/10/25 Python
python实现图片彩色转化为素描
2019/01/15 Python
Python企业编码生成系统总体系统设计概述
2019/07/26 Python
python主线程与子线程的结束顺序实例解析
2019/12/17 Python
Python3搭建http服务器的实现代码
2020/02/11 Python
使用python求解二次规划的问题
2020/02/29 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
2020/05/28 Python
使用SVG实现提示框功能的示例代码
2020/06/05 HTML / CSS
中间件分为哪几类
2016/09/18 面试题
2019中小学生安全过暑期倡议书
2019/06/24 职场文书
导游词之杭州西湖
2019/09/19 职场文书
使用HBuilder制作一个简单的HTML5网页
2022/07/07 HTML / CSS