pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python代码解决RenderView窗口not found问题
Aug 28 Python
Python排序搜索基本算法之冒泡排序实例分析
Dec 09 Python
selenium+python自动化测试之鼠标和键盘事件
Jan 23 Python
Python中单线程、多线程和多进程的效率对比实验实例
May 14 Python
pyqt 实现在Widgets中显示图片和文字的方法
Jun 13 Python
对Python3之方法的覆盖与super函数详解
Jun 26 Python
Python使用百度api做人脸对比的方法
Aug 28 Python
利用Python裁切tiff图像且读取tiff,shp文件的实例
Mar 10 Python
Django-migrate报错问题解决方案
Apr 21 Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 Python
python多线程semaphore实现线程数控制的示例
Aug 10 Python
python数字图像处理:图像的绘制
Jun 28 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
无数据库的详细域名查询程序PHP版(2)
2006/10/09 PHP
php下过滤HTML代码的函数
2007/12/10 PHP
PHP 开发环境配置(Zend Studio)
2010/04/28 PHP
PHP连接SQLServer2005的方法
2015/01/27 PHP
微信JSSDK分享功能图文实例详解
2019/04/08 PHP
如何在centos8自定义目录安装php7.3
2019/11/28 PHP
在网页中屏蔽快捷键
2006/09/06 Javascript
(仅IE下有效)关于checkbox 三态
2007/05/12 Javascript
Javascript延迟执行实现方法(setTimeout)
2010/12/30 Javascript
数组方法解决JS字符串连接性能问题有争议
2011/01/12 Javascript
javascript+html5实现绘制圆环的方法
2015/07/28 Javascript
js随机生成字母数字组合的字符串 随机动画数字
2015/09/02 Javascript
JS给swf传参数的实现方法
2016/09/13 Javascript
BootstrapValidator超详细教程(推荐)
2016/12/07 Javascript
微信小程序技巧之show内容展示,上传文件编码问题
2017/01/23 Javascript
快速使用node.js进行web开发详解
2017/04/26 Javascript
使用Node.js搭建静态资源服务详细教程
2017/08/02 Javascript
JS沙箱模式实例分析
2017/09/04 Javascript
javascript Function函数理解与实战
2017/12/01 Javascript
Vue.js点击切换按钮改变内容的实例讲解
2018/08/22 Javascript
基于iview的router常用控制方式
2019/05/30 Javascript
vue.js 2.0实现简单分页效果
2019/07/29 Javascript
原生js实现瀑布流效果
2020/03/09 Javascript
Openlayers测量距离与面积的实现方法
2020/09/25 Javascript
Python中的filter()函数的用法
2015/04/27 Python
Python创建普通菜单示例【基于win32ui模块】
2018/05/09 Python
教你一步步利用python实现贪吃蛇游戏
2019/06/27 Python
Python环境下安装PyGame和PyOpenGL的方法
2020/03/25 Python
Python二元算术运算常用方法解析
2020/09/15 Python
python中requests模拟登录的三种方式(携带cookie/session进行请求网站)
2020/11/17 Python
Python  Asyncio模块实现的生产消费者模型的方法
2021/03/01 Python
整理HTML5的一些新特性与Canvas的常用属性
2016/01/29 HTML / CSS
企业爱岗敬业演讲稿
2014/09/04 职场文书
卡特教练观后感
2015/06/08 职场文书
四年级语文教学反思
2016/03/03 职场文书
JS轻量级函数式编程实现XDM二
2022/06/16 Javascript