pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现堆排序算法代码
Jun 05 Python
多线程爬虫批量下载pcgame图片url 保存为xml的实现代码
Jan 17 Python
python基于urllib实现按照百度音乐分类下载mp3的方法
May 25 Python
Python 爬虫学习笔记之多线程爬虫
Sep 21 Python
TensorFlow实现创建分类器
Feb 06 Python
Python对数据进行插值和下采样的方法
Jul 03 Python
PyCharm+PySpark远程调试的环境配置的方法
Nov 29 Python
Python爬虫headers处理及网络超时问题解决方案
Jun 19 Python
解决PyCharm无法使用lxml库的问题(图解)
Dec 22 Python
python基础学习之递归函数知识总结
May 26 Python
Python实现抖音热搜定时爬取功能
Mar 16 Python
关于Python使用turtle库画任意图的问题
Apr 01 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
php表单提交问题的解决方法
2011/04/12 PHP
Laravel框架表单验证详解
2014/09/04 PHP
php常用数学函数汇总
2014/11/21 PHP
PHP内存缓存功能memcached示例
2016/10/19 PHP
PHP函数rtrim()使用中的怪异现象分析
2017/02/24 PHP
js字符编码函数区别分析
2008/06/05 Javascript
JS上传图片前的限制包括(jpg jpg gif及大小高宽)等
2012/12/19 Javascript
jQuery getJSON()+.ashx 实现分页(改进版)
2013/03/28 Javascript
JavaScript怎么判断图片是否加载完成以便获取其尺寸
2014/05/08 Javascript
jQuery调取jSon数据并展示的方法
2015/01/29 Javascript
Bootstrap导航栏各元素操作方法(表单、按钮、文本)
2015/12/28 Javascript
Bootstrap分页插件之Bootstrap Paginator实例详解
2016/10/15 Javascript
js实现适合新闻类图片的轮播效果
2017/02/05 Javascript
利用imgareaselect辅助后台实现图片上传裁剪
2017/03/02 Javascript
给Easyui-Datebox设置隐藏或者不可用的解决方法
2017/05/26 Javascript
SVG动画vivus.js库使用小结(实例代码)
2017/09/14 Javascript
使用JS判断移动端手机横竖屏状态
2018/07/30 Javascript
利用Pandas读取文件路径或文件名称包含中文的csv文件方法
2018/07/04 Python
Python2.7环境Flask框架安装简明教程【已测试】
2018/07/13 Python
Python 中Django验证码功能的实现代码
2019/06/20 Python
django echarts饼图数据动态加载的实例
2019/08/12 Python
docker-py 用Python调用Docker接口的方法
2019/08/30 Python
使用python求解二次规划的问题
2020/02/29 Python
python 安装教程之Pycharm安装及配置字体主题,换行,自动更新
2020/03/13 Python
Otticanet澳大利亚:最顶尖的世界名牌眼镜, 能得到打折季的价格
2018/08/23 全球购物
什么情况下你必须要把一个类定义为abstract的
2013/01/06 面试题
医学生临床实习自我评价
2014/03/07 职场文书
公益广告标语
2014/06/19 职场文书
2015年财务工作总结范文
2015/03/31 职场文书
初中班主任工作总结2015
2015/05/13 职场文书
2015年党员个人工作总结
2015/05/13 职场文书
升学宴学生致辞
2015/07/27 职场文书
如何制定一份可行的计划!
2019/06/21 职场文书
Vue详细的入门笔记
2021/05/10 Vue.js
Redis Cluster集群动态扩容的实现
2021/07/15 Redis
MySQL 数据表操作
2022/05/04 MySQL