利用scikitlearn画ROC曲线实例


Posted in Python onJuly 02, 2020

一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUC,ROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模块搭建。

具体实现的代码可以参照下面博友的代码,评估svm的分类指标。注意里面的一些细节需要注意,一个是调用roc_curve 方法时,指明目标标签,否则会报错。

具体是这个参数的设置pos_label ,以前在unionbigdata实习时学到的。

重点是以下的代码需要根据实际改写:

mean_tpr = 0.0 
  mean_fpr = np.linspace(0, 1, 100) 
  all_tpr = []
  
  y_target = np.r_[train_y,test_y]
  cv = StratifiedKFold(y_target, n_folds=6)
 
    #画ROC曲线和计算AUC
    fpr, tpr, thresholds = roc_curve(test_y, predict,pos_label = 2)##指定正例标签,pos_label = ###########在数之联的时候学到的,要制定正例
    
    mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
    mean_tpr[0] = 0.0                #初始处为0 
    roc_auc = auc(fpr, tpr) 
    #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
    plt.plot(fpr, tpr, lw=1, label='ROC %s (area = %0.3f)' % (classifier, roc_auc))

然后是博友的参考代码:

# -*- coding: utf-8 -*- 
""" 
Created on Sun Apr 19 08:57:13 2015 
@author: shifeng 
""" 
print(__doc__) 
 
import numpy as np 
from scipy import interp 
import matplotlib.pyplot as plt 
 
from sklearn import svm, datasets 
from sklearn.metrics import roc_curve, auc 
from sklearn.cross_validation import StratifiedKFold 
 
############################################################################### 
# Data IO and generation,导入iris数据,做数据准备 
 
# import some data to play with 
iris = datasets.load_iris() 
X = iris.data 
y = iris.target 
X, y = X[y != 2], y[y != 2]#去掉了label为2,label只能二分,才可以。 
n_samples, n_features = X.shape 
 
# Add noisy features 
random_state = np.random.RandomState(0) 
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] 
 
############################################################################### 
# Classification and ROC analysis 
#分类,做ROC分析 
 
# Run classifier with cross-validation and plot ROC curves 
#使用6折交叉验证,并且画ROC曲线 
cv = StratifiedKFold(y, n_folds=6) 
classifier = svm.SVC(kernel='linear', probability=True, 
           random_state=random_state)#注意这里,probability=True,需要,不然预测的时候会出现异常。另外rbf核效果更好些。 
mean_tpr = 0.0 
mean_fpr = np.linspace(0, 1, 100) 
all_tpr = [] 
 
for i, (train, test) in enumerate(cv): 
  #通过训练数据,使用svm线性核建立模型,并对测试集进行测试,求出预测得分 
  probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) 
#  print set(y[train])           #set([0,1]) 即label有两个类别 
#  print len(X[train]),len(X[test])    #训练集有84个,测试集有16个 
#  print "++",probas_           #predict_proba()函数输出的是测试集在lael各类别上的置信度, 
#  #在哪个类别上的置信度高,则分为哪类 
  # Compute ROC curve and area the curve 
  #通过roc_curve()函数,求出fpr和tpr,以及阈值 
  fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) 
  mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
  mean_tpr[0] = 0.0                #初始处为0 
  roc_auc = auc(fpr, tpr) 
  #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
  plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc)) 
 
#画对角线 
plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck') 
 
mean_tpr /= len(cv)           #在mean_fpr100个点,每个点处插值插值多次取平均 
mean_tpr[-1] = 1.0           #坐标最后一个点为(1,1) 
mean_auc = auc(mean_fpr, mean_tpr)   #计算平均AUC值 
#画平均ROC曲线 
#print mean_fpr,len(mean_fpr) 
#print mean_tpr 
plt.plot(mean_fpr, mean_tpr, 'k--', 
     label='Mean ROC (area = %0.2f)' % mean_auc, lw=2) 
 
plt.xlim([-0.05, 1.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('Receiver operating characteristic example') 
plt.legend(loc="lower right") 
plt.show()

补充知识:批量进行One-hot-encoder且进行特征字段拼接,并完成模型训练demo

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel
 
val data = (spark.read.format("csv")
 .option("sep", ",")
 .option("inferSchema", "true")
 .option("header", "true")
 .load("/Affairs.csv"))
 
data.createOrReplaceTempView("res1")
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = (spark.sql("select " + affairs +
 "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
 " from res1 "))
 
val categoricals = df.dtypes.filter(_._2 == "StringType") map (_._1)
val indexers = categoricals.map(
 c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)
 
val encoders = categoricals.map(
 c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc").setDropLast(false)
)
 
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.filter(_._2 != "StringType") map (_._1)
val final_colArray = (colArray_numeric ++ colArray_enc).filter(!_.contains("affairs"))
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")
 
/*
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler))
pipeline.fit(df).transform(df)
*/
 
///
// Create an XGBoost Classifier 
val xgb = new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree")).setLabelCol("affairs").setFeaturesCol("features")
 
// XGBoost paramater grid
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10,20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic")) 
  .build())
 
// Create the XGBoost pipeline
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))
 
// Setup the binary classifier evaluator
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC"))
 
// Create the Cross Validation pipeline, using XGBoost as the estimator, the
// Binary Classification evaluator, and xgbParamGrid for hyperparameters
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))
 
 // Create the model by fitting the training data
val xgbModel = cv.fit(df)
 
 // Test the data by scoring the model
val results = xgbModel.transform(df)
 
// Print out a copy of the parameters used by XGBoost, attention pipeline
(xgbModel.bestModel.asInstanceOf[PipelineModel]
 .stages(5).asInstanceOf[XGBoostClassificationModel]
 .extractParamMap().toSeq.foreach(println))
results.select("affairs","prediction").show
 
println("---Confusion Matrix------")
results.stat.crosstab("affairs","prediction").show()
 
// What was the overall accuracy of the model, using AUC
val auc = evaluator.evaluate(results)
println("----AUC--------")
println("auc="+auc)

以上这篇利用scikitlearn画ROC曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python生成随机密码或随机字符串的方法
Jul 03 Python
解决Django模板无法使用perms变量问题的方法
Sep 10 Python
用 Python 连接 MySQL 的几种方式详解
Apr 04 Python
tensorflow更改变量的值实例
Jul 30 Python
python 自定义对象的打印方法
Jan 12 Python
对python:threading.Thread类的使用方法详解
Jan 31 Python
python浪漫表白源码
Apr 05 Python
python实现键盘输入的实操方法
Jul 16 Python
Python学习笔记之Django创建第一个数据库模型的方法
Aug 07 Python
python向图片里添加文字
Nov 26 Python
Python collections模块的使用方法
Oct 09 Python
python机器学习实现oneR算法(以鸢尾data为例)
Mar 03 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
You might like
全国FM电台频率大全 - 23 四川省
2020/03/11 无线电
php计划任务之ignore_user_abort函数实现方法
2015/01/08 PHP
PHP实现的随机IP函数【国内IP段】
2016/07/20 PHP
弹出广告特效代码(一个IP只弹出一次)
2007/05/11 Javascript
动态修改DOM 里面的 id 属性的弊端分析
2008/09/03 Javascript
javascript prototype原型操作笔记
2009/12/07 Javascript
jQuery EasyUI 开源插件套装 完全替代ExtJS
2010/03/24 Javascript
IE6 fixed的完美解决方案
2011/03/31 Javascript
基于Jquery插件Uploadify实现实时显示进度条上传图片
2020/03/26 Javascript
jQuery根据name属性进行查找的用法分析
2016/06/23 Javascript
JS JSOP跨域请求实例详解
2016/07/04 Javascript
jquery 实时监听输入框值变化的完美方法(必看)
2017/01/26 Javascript
Angular.JS中指令ng-if的注意事项小结
2017/06/21 Javascript
对vue下点击事件传参和不传参的区别详解
2018/09/15 Javascript
ES6 新增的创建数组的方法(小结)
2019/08/01 Javascript
vue动态子组件的两种实现方式
2019/09/01 Javascript
微信小程序实现文件预览
2020/10/22 Javascript
Python中的is和==比较两个对象的两种方法
2017/09/06 Python
基于数据归一化以及Python实现方式
2018/07/11 Python
Python中format()格式输出全解
2019/04/12 Python
500行Python代码打造刷脸考勤系统
2019/06/03 Python
一文秒懂python读写csv xml json文件各种骚操作
2019/07/04 Python
django 类视图的使用方法详解
2019/07/24 Python
Python 合并多个TXT文件并统计词频的实现
2019/08/23 Python
python运用pygame库实现双人弹球小游戏
2019/11/25 Python
Python计算机视觉里的IOU计算实例
2020/01/17 Python
Python HTMLTestRunner可视化报告实现过程解析
2020/04/10 Python
python Socket网络编程实现C/S模式和P2P
2020/06/22 Python
支票、地址标签、包装纸和慰问卡:Current Catalog
2018/01/30 全球购物
澳大利亚领先的折扣药房:Chemist Direct(有中文站)
2018/11/24 全球购物
《毛主席在花山》教学反思
2014/04/20 职场文书
中国梦演讲稿3分钟
2014/08/19 职场文书
以幸福为主题的活动方案
2014/08/22 职场文书
详解前端任务构建利器Gulp.js使用指南
2021/04/30 Javascript
python的netCDF4批量处理NC格式文件的操作方法
2022/03/21 Python
世界无敌的ICOM IC-R9500宽频接收机
2022/03/25 无线电