pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中文编码那些事
Jun 25 Python
Python解析xml中dom元素的方法
Mar 12 Python
python利用datetime模块计算时间差
Aug 04 Python
wxpython中自定义事件的实现与使用方法分析
Jul 21 Python
python实现简单爬虫功能的示例
Oct 24 Python
Python装饰器原理与简单用法实例分析
Apr 29 Python
Numpy截取指定范围内的数据方法
Nov 14 Python
Python numpy中矩阵的基本用法汇总
Feb 12 Python
了解不常见但是实用的Python技巧
May 23 Python
Python基于BeautifulSoup爬取京东商品信息
Jun 01 Python
为什么是 Python -m
Jun 19 Python
python turtle绘图
May 04 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
非集成环境的php运行环境(Apache配置、Mysql)搭建安装图文教程
2016/04/12 PHP
php错误日志简单配置方法
2016/07/11 PHP
PHP实现对xml的增删改查操作案例分析
2017/05/19 PHP
弹出模态框modal的实现方法及实例
2017/09/19 PHP
thinkphp集成前端脚手架Vue-cli的教程图解
2018/08/30 PHP
jQuery动画效果实现图片无缝连续滚动
2016/01/12 Javascript
jQuery实现的左右移动焦点图效果
2016/01/14 Javascript
JS实现左右无缝轮播图代码
2016/05/01 Javascript
JavaScript prototype属性详解
2016/10/25 Javascript
关于Sequelize连接查询时inlude中model和association的区别详解
2017/02/27 Javascript
vue子组件使用自定义事件向父组件传递数据
2017/05/27 Javascript
JS实现关键词高亮显示正则匹配
2018/06/22 Javascript
jQuery实现获取动态添加的标签对象示例
2018/06/28 jQuery
解决前后端分离 vue+springboot 跨域 session+cookie失效问题
2019/05/13 Javascript
bootstrap table插件动态加载表头
2019/07/19 Javascript
Vue Router的手写实现方法实现
2020/03/02 Javascript
详解Vue中的自定义指令
2020/12/07 Vue.js
python里对list中的整数求平均并排序
2014/09/12 Python
python检测是文件还是目录的方法
2015/07/03 Python
Python实现文件按照日期命名的方法
2015/07/09 Python
Python内置模块ConfigParser实现配置读写功能的方法
2018/02/12 Python
python编写弹球游戏的实现代码
2018/03/12 Python
python: line=f.readlines()消除line中\n的方法
2018/03/19 Python
对Python中的@classmethod用法详解
2018/04/21 Python
使用python远程操作linux过程解析
2019/12/04 Python
python实现字符串和数字拼接
2020/03/02 Python
python实现将两个文件夹合并至另一个文件夹(制作数据集)
2020/04/03 Python
计算s=f(f(-1.4))的值
2014/05/06 面试题
留学生如何写好自荐信
2013/12/27 职场文书
酒店保安领班职务说明书
2014/03/04 职场文书
推荐信怎么写
2014/05/09 职场文书
党员公开承诺书内容
2014/05/20 职场文书
政治学求职信
2014/06/03 职场文书
教师优秀党员事迹材料
2014/08/14 职场文书
学校做一个有道德的人活动方案
2014/08/23 职场文书
房屋认购协议书
2015/01/29 职场文书