pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Web框架Pylons中使用MongoDB的例子
Dec 03 Python
python利用datetime模块计算时间差
Aug 04 Python
玩转python爬虫之URLError异常处理
Feb 17 Python
pandas的唯一值、值计数以及成员资格的示例
Jul 25 Python
利用python numpy+matplotlib绘制股票k线图的方法
Jun 26 Python
Python Threading 线程/互斥锁/死锁/GIL锁
Jul 21 Python
基于python二叉树的构造和打印例子
Aug 09 Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 Python
基于python实现音乐播放器代码实例
Jul 01 Python
Keras搭建自编码器操作
Jul 03 Python
Python中的With语句的使用及原理
Jul 29 Python
Python多线程实用方法以及共享变量资源竞争问题
Apr 12 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
php程序之die调试法 快速解决错误
2009/09/17 PHP
简单的cookie计数器实现源码
2013/06/07 PHP
百度站点地图(百度sitemap)生成方法分享
2014/01/09 PHP
Javascript与PHP验证用户输入URL地址是否正确
2014/10/09 PHP
PHP编程获取音频文件时长的方法【基于getid3类】
2017/04/20 PHP
php模拟post提交请求调用接口示例解析
2020/08/07 PHP
jQuery(非HTML5)可编辑表格实现代码
2012/12/11 Javascript
JQUERY 设置SELECT选中项代码
2014/02/07 Javascript
jquery实现textarea输入框限制字数的方法
2015/01/15 Javascript
JS获取图片高度宽度的方法分享
2015/04/17 Javascript
jquery插件EasyUI中form表单提交实例分享
2016/01/11 Javascript
体验jQuery和AngularJS的不同点及AngularJS的迷人之处
2016/02/02 Javascript
深入解析JavaScript中的arguments对象
2016/06/12 Javascript
利用JavaScript实现拖拽改变元素大小
2016/12/14 Javascript
jQuery实现动态文字搜索功能
2017/01/05 Javascript
详解Vue中使用v-for语句抛出错误的解决方案
2017/05/04 Javascript
javascript使用正则实现去掉字符串前面的所有0
2018/07/23 Javascript
node.js调用C++函数的方法示例
2018/09/21 Javascript
sortable+element 实现表格行拖拽的方法示例
2019/06/07 Javascript
js中的面向对象之对象常见创建方法详解
2019/12/16 Javascript
[46:48]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第三局
2016/02/25 DOTA
Python实现将Excel转换为json的方法示例
2017/08/05 Python
python监控linux内存并写入mongodb(推荐)
2017/09/11 Python
Python使用pip安装pySerial串口通讯模块
2018/04/20 Python
python变量的存储原理详解
2019/07/10 Python
Python callable内置函数原理解析
2020/03/05 Python
Python基于QQ邮箱实现SSL发送
2020/04/26 Python
css3动画鼠标放上图片逐渐变大鼠标离开图片逐渐缩小效果
2021/01/27 HTML / CSS
英国女士和男士时尚服装网上购物:Top Labels Online
2018/03/25 全球购物
幼儿园开学家长寄语
2014/01/19 职场文书
国家助学金获奖感言
2014/01/31 职场文书
欢送退休感言
2014/02/08 职场文书
2014年环保局工作总结
2014/12/11 职场文书
2015元旦晚会主持人开场白+结束语
2014/12/14 职场文书
有关保护环境的宣传标语100条
2019/08/07 职场文书
Python爬虫进阶之Beautiful Soup库详解
2021/04/29 Python