利用scikitlearn画ROC曲线实例


Posted in Python onJuly 02, 2020

一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUC,ROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模块搭建。

具体实现的代码可以参照下面博友的代码,评估svm的分类指标。注意里面的一些细节需要注意,一个是调用roc_curve 方法时,指明目标标签,否则会报错。

具体是这个参数的设置pos_label ,以前在unionbigdata实习时学到的。

重点是以下的代码需要根据实际改写:

mean_tpr = 0.0 
  mean_fpr = np.linspace(0, 1, 100) 
  all_tpr = []
  
  y_target = np.r_[train_y,test_y]
  cv = StratifiedKFold(y_target, n_folds=6)
 
    #画ROC曲线和计算AUC
    fpr, tpr, thresholds = roc_curve(test_y, predict,pos_label = 2)##指定正例标签,pos_label = ###########在数之联的时候学到的,要制定正例
    
    mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
    mean_tpr[0] = 0.0                #初始处为0 
    roc_auc = auc(fpr, tpr) 
    #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
    plt.plot(fpr, tpr, lw=1, label='ROC %s (area = %0.3f)' % (classifier, roc_auc))

然后是博友的参考代码:

# -*- coding: utf-8 -*- 
""" 
Created on Sun Apr 19 08:57:13 2015 
@author: shifeng 
""" 
print(__doc__) 
 
import numpy as np 
from scipy import interp 
import matplotlib.pyplot as plt 
 
from sklearn import svm, datasets 
from sklearn.metrics import roc_curve, auc 
from sklearn.cross_validation import StratifiedKFold 
 
############################################################################### 
# Data IO and generation,导入iris数据,做数据准备 
 
# import some data to play with 
iris = datasets.load_iris() 
X = iris.data 
y = iris.target 
X, y = X[y != 2], y[y != 2]#去掉了label为2,label只能二分,才可以。 
n_samples, n_features = X.shape 
 
# Add noisy features 
random_state = np.random.RandomState(0) 
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] 
 
############################################################################### 
# Classification and ROC analysis 
#分类,做ROC分析 
 
# Run classifier with cross-validation and plot ROC curves 
#使用6折交叉验证,并且画ROC曲线 
cv = StratifiedKFold(y, n_folds=6) 
classifier = svm.SVC(kernel='linear', probability=True, 
           random_state=random_state)#注意这里,probability=True,需要,不然预测的时候会出现异常。另外rbf核效果更好些。 
mean_tpr = 0.0 
mean_fpr = np.linspace(0, 1, 100) 
all_tpr = [] 
 
for i, (train, test) in enumerate(cv): 
  #通过训练数据,使用svm线性核建立模型,并对测试集进行测试,求出预测得分 
  probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) 
#  print set(y[train])           #set([0,1]) 即label有两个类别 
#  print len(X[train]),len(X[test])    #训练集有84个,测试集有16个 
#  print "++",probas_           #predict_proba()函数输出的是测试集在lael各类别上的置信度, 
#  #在哪个类别上的置信度高,则分为哪类 
  # Compute ROC curve and area the curve 
  #通过roc_curve()函数,求出fpr和tpr,以及阈值 
  fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) 
  mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
  mean_tpr[0] = 0.0                #初始处为0 
  roc_auc = auc(fpr, tpr) 
  #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
  plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc)) 
 
#画对角线 
plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck') 
 
mean_tpr /= len(cv)           #在mean_fpr100个点,每个点处插值插值多次取平均 
mean_tpr[-1] = 1.0           #坐标最后一个点为(1,1) 
mean_auc = auc(mean_fpr, mean_tpr)   #计算平均AUC值 
#画平均ROC曲线 
#print mean_fpr,len(mean_fpr) 
#print mean_tpr 
plt.plot(mean_fpr, mean_tpr, 'k--', 
     label='Mean ROC (area = %0.2f)' % mean_auc, lw=2) 
 
plt.xlim([-0.05, 1.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('Receiver operating characteristic example') 
plt.legend(loc="lower right") 
plt.show()

补充知识:批量进行One-hot-encoder且进行特征字段拼接,并完成模型训练demo

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel
 
val data = (spark.read.format("csv")
 .option("sep", ",")
 .option("inferSchema", "true")
 .option("header", "true")
 .load("/Affairs.csv"))
 
data.createOrReplaceTempView("res1")
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = (spark.sql("select " + affairs +
 "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
 " from res1 "))
 
val categoricals = df.dtypes.filter(_._2 == "StringType") map (_._1)
val indexers = categoricals.map(
 c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)
 
val encoders = categoricals.map(
 c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc").setDropLast(false)
)
 
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.filter(_._2 != "StringType") map (_._1)
val final_colArray = (colArray_numeric ++ colArray_enc).filter(!_.contains("affairs"))
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")
 
/*
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler))
pipeline.fit(df).transform(df)
*/
 
///
// Create an XGBoost Classifier 
val xgb = new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree")).setLabelCol("affairs").setFeaturesCol("features")
 
// XGBoost paramater grid
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10,20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic")) 
  .build())
 
// Create the XGBoost pipeline
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))
 
// Setup the binary classifier evaluator
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC"))
 
// Create the Cross Validation pipeline, using XGBoost as the estimator, the
// Binary Classification evaluator, and xgbParamGrid for hyperparameters
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))
 
 // Create the model by fitting the training data
val xgbModel = cv.fit(df)
 
 // Test the data by scoring the model
val results = xgbModel.transform(df)
 
// Print out a copy of the parameters used by XGBoost, attention pipeline
(xgbModel.bestModel.asInstanceOf[PipelineModel]
 .stages(5).asInstanceOf[XGBoostClassificationModel]
 .extractParamMap().toSeq.foreach(println))
results.select("affairs","prediction").show
 
println("---Confusion Matrix------")
results.stat.crosstab("affairs","prediction").show()
 
// What was the overall accuracy of the model, using AUC
val auc = evaluator.evaluate(results)
println("----AUC--------")
println("auc="+auc)

以上这篇利用scikitlearn画ROC曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中处理unchecked未捕获异常实例
Jan 17 Python
Python操作串口的方法
Jun 17 Python
深入理解python中的浅拷贝和深拷贝
May 30 Python
详解Python中的Descriptor描述符类
Jun 14 Python
python 读写、创建 文件的方法(必看)
Sep 12 Python
利用python实现xml与数据库读取转换的方法
Jun 17 Python
Python线程之定位与销毁的实现
Feb 17 Python
Django REST framework 视图和路由详解
Jul 19 Python
Python使用循环神经网络解决文本分类问题的方法详解
Jan 16 Python
Python Websocket服务端通信的使用示例
Feb 25 Python
Python2 与Python3的版本区别实例分析
Mar 30 Python
OpenCV 之按位运算举例解析
Jun 19 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
You might like
一步一步学习PHP(5) 类和对象
2010/02/16 PHP
PHP求最大子序列和的算法实现
2011/06/24 PHP
php获取远程图片体积大小的实例
2013/11/12 PHP
PHP通过加锁实现并发情况下抢码功能
2016/08/10 PHP
php实现HTML实体编号与非ASCII字符串相互转换类实例
2016/11/02 PHP
在云虚拟主机部署thinkphp5项目的步骤详解
2017/12/21 PHP
PHP后台实现微信小程序登录
2018/08/03 PHP
对textarea框的代码调试,而且功能上使用非常方便,酷
2006/06/30 Javascript
javascript操作文本框readOnly
2007/05/15 Javascript
Javascript 按位左移运算符使用介绍(
2014/02/04 Javascript
为jquery的ajaxfileupload增加附加参数的方法
2014/03/04 Javascript
深入浅析JavaScript系列(13):This? Yes,this!
2016/01/05 Javascript
JavaScript基础知识之方法汇总结
2016/01/24 Javascript
KnockoutJS 3.X API 第四章之事件event绑定
2016/10/10 Javascript
bootstrap PrintThis打印插件使用详解
2017/02/20 Javascript
[01:27:30]LGD vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/19 DOTA
python开发之thread实现布朗运动的方法
2015/11/11 Python
Python基于回溯法子集树模板解决选排问题示例
2017/09/07 Python
Python3实现统计单词表中每个字母出现频率的方法示例
2019/01/28 Python
python中使用ctypes调用so传参设置遇到的问题及解决方法
2019/06/19 Python
Django 批量插入数据的实现方法
2020/01/12 Python
关于matplotlib-legend 位置属性 loc 使用说明
2020/05/16 Python
详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系
2020/08/04 Python
不同浏览器对CSS3和HTML5的支持状况
2009/10/31 HTML / CSS
html5小技巧之通过document.head获取head元素
2014/06/04 HTML / CSS
html5的自定义data-*属性与jquery的data()方法的使用
2014/07/02 HTML / CSS
奥兰多迪士尼门票折扣:Undercover Tourist
2018/07/09 全球购物
护理专业个人求职简历的自我评价
2013/10/13 职场文书
教育局长自荐信范文
2013/12/22 职场文书
大学生饮食配送创业计划书
2014/01/04 职场文书
工程资料员岗位职责
2014/03/10 职场文书
商务日语专业毕业生自荐信
2014/03/27 职场文书
工作目标责任书
2014/07/23 职场文书
2015年秋季运动会广播稿
2015/08/19 职场文书
Java输出Hello World完美过程解析
2021/06/13 Java/Android
python字符串拼接.join()和拆分.split()详解
2021/11/23 Python