pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
横向对比分析Python解析XML的四种方式
Mar 30 Python
Python使用requests发送POST请求实例代码
Jan 25 Python
Django中的Signal代码详解
Feb 05 Python
Django后台获取前端post上传的文件方法
May 28 Python
对django views中 request, response的常用操作详解
Jul 17 Python
django解决订单并发问题【推荐】
Jul 31 Python
python 定时器每天就执行一次的实现代码
Aug 14 Python
Pytorch evaluation每次运行结果不同的解决
Jan 02 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
Jun 05 Python
你需要学会的8个Python列表技巧
Jun 24 Python
python中pyplot基础图标函数整理
Nov 10 Python
Python加密与解密模块hashlib与hmac
Jun 05 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
通过5个php实例细致说明传值与传引用的区别
2012/08/08 PHP
将文本输入框内容加入表中的js代码
2013/08/18 Javascript
教你用jquery实现iframe自适应高度
2014/06/11 Javascript
jQuery给动态添加的元素绑定事件的方法
2015/03/09 Javascript
基于jquery实现瀑布流布局
2020/06/28 Javascript
浅析JS获取url中的参数实例代码
2016/06/14 Javascript
js传值后台中文出现乱码的解决方法
2016/06/30 Javascript
Async Validator 异步验证使用说明
2017/07/03 Javascript
使用socket.io制做简易WEB聊天室
2018/01/02 Javascript
浅析微信扫码登录原理(小结)
2018/10/29 Javascript
浅谈React碰到v-if
2018/11/04 Javascript
jQuery实现购物车的总价计算和总价传值功能
2018/11/28 jQuery
layui实现下拉框三级联动
2019/07/26 Javascript
超轻量级的js时间库miment使用解析
2019/08/02 Javascript
js实现跟随鼠标移动的小球
2019/08/26 Javascript
解决vue scoped html样式无效的问题
2020/10/24 Javascript
详解nginx配置vue h5 history去除#号
2020/11/09 Javascript
[01:26]DOTA2荣耀之路2:iG,China
2018/05/24 DOTA
[01:03:41]完美世界DOTA2联赛PWL S3 DLG vs Phoenix 第一场 12.17
2020/12/19 DOTA
寻找网站后台地址的python脚本
2014/09/01 Python
跟老齐学Python之使用Python查询更新数据库
2014/11/25 Python
python协程用法实例分析
2015/06/04 Python
深入理解python中函数传递参数是值传递还是引用传递
2017/11/07 Python
python 对给定可迭代集合统计出现频率,并排序的方法
2018/10/18 Python
opencv python 图像轮廓/检测轮廓/绘制轮廓的方法
2019/07/03 Python
python GUI库图形界面开发之PyQt5信号与槽的高级使用技巧(自定义信号与槽)详解与实例
2020/03/06 Python
Python语法垃圾回收机制原理解析
2020/03/25 Python
Python爬虫定时计划任务的几种常见方法(推荐)
2021/01/15 Python
ALDO加拿大官网:加拿大女鞋品牌
2018/12/22 全球购物
AMAVII眼镜官网:时尚和设计师太阳镜
2019/05/05 全球购物
《分一分》教学反思
2014/04/13 职场文书
自荐信格式模板
2015/03/27 职场文书
2015年终个人政治思想工作总结
2015/11/24 职场文书
2016年植树节红领巾广播稿
2015/12/17 职场文书
2016教师学习党章心得体会
2016/01/15 职场文书
MySQL事务的隔离级别详情
2022/07/15 MySQL