pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作CouchDB的方法
Oct 08 Python
Python字符串替换实例分析
May 11 Python
django轻松使用富文本编辑器CKEditor的方法
Mar 30 Python
Django实现表单验证
Sep 08 Python
python实现文件助手中查看微信撤回消息
Apr 29 Python
Python实现九宫格式的朋友圈功能内附“马云”朋友圈
May 07 Python
python3.8 微信发送服务器监控报警消息代码实现
Nov 05 Python
Python 判断时间是否在时间区间内的实例
May 16 Python
520使用Python实现“我爱你”表白
May 20 Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 Python
python 怎样进行内存管理
Nov 10 Python
Python WebSocket长连接心跳与短连接的示例
Nov 24 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
人尽可用的Windows技巧小贴士之下篇
2007/03/22 PHP
PHP 替换模板变量实现步骤
2009/08/24 PHP
php dirname(__FILE__) 获取当前文件的绝对路径
2011/06/28 PHP
PHP读取文件的常见几种方法
2016/11/03 PHP
PHP判断json格式是否正确的实现代码
2017/09/20 PHP
用于节点操作的API,颠覆原生操作HTML DOM节点的API
2010/12/11 Javascript
基于jQuery的动态表格插件
2011/03/28 Javascript
jquery.artwl.thickbox.js  一个非常简单好用的jQuery弹出层插件
2012/03/01 Javascript
jQuery之$(document).ready()使用介绍
2012/04/05 Javascript
浏览器解析js生成的html出现样式问题的解决方法
2012/04/16 Javascript
javascript-表格排序(降序/反序)实现介绍(附图)
2013/05/30 Javascript
jquery获取对象的方法足以应付常见的各种类型的对象
2014/05/14 Javascript
js获取网页可见区域、正文以及屏幕分辨率的高度
2014/05/15 Javascript
javascript获取重复次数最多的字符
2015/07/08 Javascript
js下拉选择框与输入框联动实现添加选中值到输入框的方法
2015/08/17 Javascript
jquery结婚电子请柬特效源码分享
2015/08/21 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
Bootstrap Table的使用总结
2016/10/08 Javascript
浅谈箭头函数写法在ReactJs中的使用
2017/08/22 Javascript
简述Angular 5 快速入门
2017/11/04 Javascript
Vue的路由动态重定向和导航守卫实例
2018/03/17 Javascript
vue动态设置页面title的方法实例
2020/08/23 Javascript
[08:44]和酒神一起战斗 DOTA2教你做大人
2014/03/27 DOTA
用Python将动态GIF图片倒放播放的方法
2016/11/02 Python
python 获取微信好友列表的方法(微信web)
2019/02/21 Python
python+mysql实现个人论文管理系统
2019/10/25 Python
Pandas之缺失数据的实现
2021/01/06 Python
CSS3动画特效在活动页中的应用
2020/01/21 HTML / CSS
英国评分最高的女性剃须刀订阅盒:FFS Beauty
2018/01/25 全球购物
阿玛尼美妆俄罗斯官网:Giorgio Armani Beauty RU
2020/07/19 全球购物
SQL Server数据库笔试题和答案
2016/02/04 面试题
《赠汪伦》教学反思
2014/04/12 职场文书
《灰雀》教学反思
2016/02/19 职场文书
用python实现监控视频人数统计
2021/05/21 Python
详解Python中的for循环
2022/04/30 Python
深入理解 Golang 的字符串
2022/05/04 Golang