pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之字典,你还记得吗?
Sep 20 Python
python实现给字典添加条目的方法
Sep 25 Python
Python中使用md5sum检查目录中相同文件代码分享
Feb 02 Python
Python的collections模块中的OrderedDict有序字典
Jul 07 Python
Python在线运行代码助手
Jul 15 Python
python编程实现希尔排序
Apr 13 Python
Python中字典(dict)合并的四种方法总结
Aug 10 Python
python pandas 对时间序列文件处理的实例
Jun 22 Python
python使用folium库绘制地图点击框
Sep 21 Python
浅谈sklearn中predict与predict_proba区别
Jun 28 Python
python中pymysql包操作数据库方法
Apr 19 Python
Python使用pandas导入xlsx格式的excel文件内容操作代码
Dec 24 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
?生?D片??C字串
2006/12/06 PHP
phpmail类发送邮件函数代码
2012/02/20 PHP
php unset全局变量运用问题的深入解析
2013/06/17 PHP
简单的php缓存类分享     php缓存机制
2014/01/22 PHP
PHP错误和异长常处理总结
2014/03/06 PHP
smarty半小时快速上手入门教程
2014/10/27 PHP
php查找字符串中第一个非0的位置截取
2017/02/27 PHP
tp5.1 框架join方法用法实例分析
2020/05/26 PHP
JavaScript delete操作符应用实例
2009/01/13 Javascript
jquery 中多条件选择器,相对选择器,层次选择器的区别
2012/07/03 Javascript
超级简单的jquery操作表格方法
2014/12/15 Javascript
JavaScript记录光标在编辑器中位置的实现方法
2016/04/22 Javascript
Vue.js实现一个自定义分页组件vue-paginaiton
2016/09/05 Javascript
前端实现文件的断点续传(前端文件提交+后端PHP文件接收)
2016/11/04 Javascript
JavaScript 完成注册页面表单校验的实例
2017/08/19 Javascript
webpack中的热刷新与热加载的区别
2018/04/09 Javascript
详解ajax的data参数错误导致页面崩溃
2018/04/30 Javascript
Angular6 正则表达式允许输入部分中文字符
2018/09/10 Javascript
Vue toFixed保留两位小数的3种方式
2020/10/23 Javascript
[39:18]完美世界DOTA2联赛PWL S3 Forest vs LBZS 第二场 12.17
2020/12/19 DOTA
python基础知识小结之集合
2015/11/25 Python
Ubuntu 16.04 LTS中源码安装Python 3.6.0的方法教程
2016/12/27 Python
利用python实现微信头像加红色数字功能
2018/03/26 Python
Ubuntu16.04/树莓派Python3+opencv配置教程(分享)
2018/04/02 Python
Python Flask框架模板操作实例分析
2019/05/03 Python
python GUI实现小球满屏乱跑效果
2019/05/09 Python
使用pyecharts生成Echarts网页的实例
2019/08/12 Python
Python 爬虫实现增加播客访问量的方法实现
2019/10/31 Python
TensorFlow-gpu和opencv安装详细教程
2020/06/30 Python
新加坡时尚网上购物:Zalora新加坡
2016/07/26 全球购物
马来西亚在线药房:RoyalePharma
2019/12/01 全球购物
中学教师实习自我鉴定
2013/09/28 职场文书
小学生家长意见
2015/06/03 职场文书
小学教师师德培训心得体会
2016/01/09 职场文书
话题作文之生命的旋律
2019/12/17 职场文书
Java 超详细讲解数据结构中的堆的应用
2022/04/02 Java/Android