pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中map,reduce,filter和sorted函数的使用方法
Aug 17 Python
Python 中的with关键字使用详解
Sep 11 Python
教你用一行Python代码实现并行任务(附代码)
Feb 02 Python
Python线程池模块ThreadPoolExecutor用法分析
Dec 28 Python
Python实现字典按key或者value进行排序操作示例【sorted】
May 03 Python
Python定时任务APScheduler的实例实例详解
Jul 22 Python
opencv 获取rtsp流媒体视频的实现方法
Aug 23 Python
python基于event实现线程间通信控制
Jan 13 Python
Keras 使用 Lambda层详解
Jun 10 Python
python打开音乐文件的实例方法
Jul 21 Python
python利用pytesseract 实现本地识别图片文字
Dec 14 Python
深入理解python多线程编程
Apr 18 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
安装PHP可能遇到的问题“无法载入mysql扩展” 的解决方法
2007/04/16 PHP
yii框架builder、update、delete使用方法
2014/04/30 PHP
正确的PHP匹配UTF-8中文的正则表达式
2015/05/13 PHP
浅谈PHP5.6 与 PHP7.0 区别
2019/10/09 PHP
纯JavaScript实现的完美渐变弹出层效果代码
2010/04/02 Javascript
jQuery学习笔记之DOM对象和jQuery对象
2010/12/22 Javascript
B/S模式项目中常用的javascript汇总
2013/12/17 Javascript
多种方法实现360浏览器下禁止自动填写用户名密码
2014/06/16 Javascript
让checkbox不选中即将选中的checkbox不选中
2014/07/11 Javascript
JavaScript编程的单例设计模讲解
2015/11/10 Javascript
jQuery Ajax Post 回调函数不执行问题的解决方法
2016/08/15 Javascript
用jQuery实现优酷首页轮播图
2017/01/09 Javascript
Bootstrap中glyphicons-halflings-regular.woff字体报404错notfound的解决方法
2017/01/19 Javascript
jQuery使用unlock.js插件实现滑动解锁
2017/04/04 jQuery
从0到1构建vueSSR项目之node以及vue-cli3的配置
2019/03/07 Javascript
微信小程序实现3D轮播图效果(非swiper组件)
2019/09/21 Javascript
JavaScript多种图形实现代码实例
2020/06/28 Javascript
[06:25]第二届DOTA2亚洲邀请赛主赛事第二天比赛集锦.mp4
2017/04/03 DOTA
python实现Floyd算法
2018/01/03 Python
浅析python中numpy包中的argsort函数的使用
2018/08/30 Python
python下的opencv画矩形和文字注释的实现方法
2019/07/09 Python
在Python3 numpy中mean和average的区别详解
2019/08/24 Python
Python unittest discover批量执行代码实例
2020/09/08 Python
基于Python爬取股票数据过程详解
2020/10/21 Python
微软中国官方旗舰店:销售Surface、Xbox One、笔记本电脑、Office
2018/07/23 全球购物
英国设计师珠宝网站:Joshua James Jewellery
2020/03/01 全球购物
俄罗斯园林植物网上商店:Garshinka
2020/07/16 全球购物
电子商务专业学生的自我鉴定
2013/11/28 职场文书
大学四年规划书范文
2013/12/27 职场文书
优秀求职信范文分享
2014/01/26 职场文书
中专自我鉴定
2014/02/05 职场文书
淘宝店铺营销方案
2014/02/13 职场文书
给校长的建议书600字
2014/05/15 职场文书
不尊敬老师检讨书范文
2014/11/19 职场文书
MySQL深度分页(千万级数据量如何快速分页)
2021/07/25 MySQL
HTML 里 img 元素的 src 和 srcset 属性的区别详解
2023/05/21 HTML / CSS