pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python写个小监控
Jan 27 Python
Python网络爬虫出现乱码问题的解决方法
Jan 05 Python
解决安装tensorflow遇到无法卸载numpy 1.8.0rc1的问题
Jun 13 Python
Sanic框架蓝图用法实例分析
Jul 17 Python
python使用Plotly绘图工具绘制散点图、线形图
Apr 02 Python
用Python实现二叉树、二叉树非递归遍历及绘制的例子
Aug 09 Python
python 遍历pd.Series的index和value
Nov 26 Python
Pycharm 2020年最新激活码(亲测有效)
Sep 18 Python
Matplotlib使用Cursor实现UI定位的示例代码
Mar 12 Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 Python
Python使用struct处理二进制(pack和unpack用法)
Nov 12 Python
Python读取ini配置文件传参的简单示例
Jan 05 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
简单的PHP多图上传小程序代码
2011/07/17 PHP
php获取从百度搜索进入网站的关键词的详细代码
2014/01/08 PHP
js 代码集(学习js的朋友可以看下)
2009/07/22 Javascript
js中方法重载如何实现?以及函数的参数问题
2013/08/01 Javascript
使用js简单实现了tree树菜单
2013/11/20 Javascript
Javascript中数组sort和reverse用法分析
2014/12/30 Javascript
js的回调函数详解
2015/01/05 Javascript
浅谈javascript中的加减时间
2016/07/12 Javascript
js实现贪吃蛇小游戏(容易理解)
2017/01/22 Javascript
Vue监听数组变化源码解析
2017/03/09 Javascript
vue.js事件处理器是什么
2017/03/20 Javascript
vue-router 中router-view不能渲染的解决方法
2017/05/23 Javascript
Vue2.0仿饿了么webapp单页面应用详细步骤
2018/07/08 Javascript
详解微信小程序实现WebSocket心跳重连
2018/07/31 Javascript
element-ui组件table实现自定义筛选功能的示例代码
2019/03/15 Javascript
使用vue2.6实现抖音【时间轮盘】屏保效果附源码
2019/04/24 Javascript
javascript获取select值的方法完整实例
2019/06/20 Javascript
原生js+ajax分页组件
2020/01/30 Javascript
Vue开发中遇到的跨域问题及解决方法
2020/02/11 Javascript
解决vue的touchStart事件及click事件冲突问题
2020/07/21 Javascript
vue中echarts的用法及与elementui-select的协同绑定操作
2020/11/17 Vue.js
[06:24]DOTA2亚洲邀请赛小组赛第三日 TOP10精彩集锦
2015/02/01 DOTA
Python RabbitMQ消息队列实现rpc
2018/05/30 Python
使用python远程操作linux过程解析
2019/12/04 Python
Python合并2个字典成1个新字典的方法(9种)
2019/12/19 Python
Pycharm安装python库的方法
2020/11/24 Python
HTML5添加禁止缩放功能
2017/11/03 HTML / CSS
找到不普通的东西:Bonanza
2016/10/20 全球购物
C语言面试题
2015/10/30 面试题
init进程的作用
2012/04/12 面试题
法学专业毕业生自荐信范文
2013/12/18 职场文书
语文教学感言
2014/02/06 职场文书
小学安全教育材料
2014/02/17 职场文书
化学专业毕业生求职信
2014/07/28 职场文书
2014最新自愿离婚协议书范本
2014/11/19 职场文书
在职人员跳槽求职信
2015/03/20 职场文书