利用scikitlearn画ROC曲线实例


Posted in Python onJuly 02, 2020

一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUC,ROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模块搭建。

具体实现的代码可以参照下面博友的代码,评估svm的分类指标。注意里面的一些细节需要注意,一个是调用roc_curve 方法时,指明目标标签,否则会报错。

具体是这个参数的设置pos_label ,以前在unionbigdata实习时学到的。

重点是以下的代码需要根据实际改写:

mean_tpr = 0.0 
  mean_fpr = np.linspace(0, 1, 100) 
  all_tpr = []
  
  y_target = np.r_[train_y,test_y]
  cv = StratifiedKFold(y_target, n_folds=6)
 
    #画ROC曲线和计算AUC
    fpr, tpr, thresholds = roc_curve(test_y, predict,pos_label = 2)##指定正例标签,pos_label = ###########在数之联的时候学到的,要制定正例
    
    mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
    mean_tpr[0] = 0.0                #初始处为0 
    roc_auc = auc(fpr, tpr) 
    #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
    plt.plot(fpr, tpr, lw=1, label='ROC %s (area = %0.3f)' % (classifier, roc_auc))

然后是博友的参考代码:

# -*- coding: utf-8 -*- 
""" 
Created on Sun Apr 19 08:57:13 2015 
@author: shifeng 
""" 
print(__doc__) 
 
import numpy as np 
from scipy import interp 
import matplotlib.pyplot as plt 
 
from sklearn import svm, datasets 
from sklearn.metrics import roc_curve, auc 
from sklearn.cross_validation import StratifiedKFold 
 
############################################################################### 
# Data IO and generation,导入iris数据,做数据准备 
 
# import some data to play with 
iris = datasets.load_iris() 
X = iris.data 
y = iris.target 
X, y = X[y != 2], y[y != 2]#去掉了label为2,label只能二分,才可以。 
n_samples, n_features = X.shape 
 
# Add noisy features 
random_state = np.random.RandomState(0) 
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] 
 
############################################################################### 
# Classification and ROC analysis 
#分类,做ROC分析 
 
# Run classifier with cross-validation and plot ROC curves 
#使用6折交叉验证,并且画ROC曲线 
cv = StratifiedKFold(y, n_folds=6) 
classifier = svm.SVC(kernel='linear', probability=True, 
           random_state=random_state)#注意这里,probability=True,需要,不然预测的时候会出现异常。另外rbf核效果更好些。 
mean_tpr = 0.0 
mean_fpr = np.linspace(0, 1, 100) 
all_tpr = [] 
 
for i, (train, test) in enumerate(cv): 
  #通过训练数据,使用svm线性核建立模型,并对测试集进行测试,求出预测得分 
  probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) 
#  print set(y[train])           #set([0,1]) 即label有两个类别 
#  print len(X[train]),len(X[test])    #训练集有84个,测试集有16个 
#  print "++",probas_           #predict_proba()函数输出的是测试集在lael各类别上的置信度, 
#  #在哪个类别上的置信度高,则分为哪类 
  # Compute ROC curve and area the curve 
  #通过roc_curve()函数,求出fpr和tpr,以及阈值 
  fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) 
  mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
  mean_tpr[0] = 0.0                #初始处为0 
  roc_auc = auc(fpr, tpr) 
  #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
  plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc)) 
 
#画对角线 
plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck') 
 
mean_tpr /= len(cv)           #在mean_fpr100个点,每个点处插值插值多次取平均 
mean_tpr[-1] = 1.0           #坐标最后一个点为(1,1) 
mean_auc = auc(mean_fpr, mean_tpr)   #计算平均AUC值 
#画平均ROC曲线 
#print mean_fpr,len(mean_fpr) 
#print mean_tpr 
plt.plot(mean_fpr, mean_tpr, 'k--', 
     label='Mean ROC (area = %0.2f)' % mean_auc, lw=2) 
 
plt.xlim([-0.05, 1.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('Receiver operating characteristic example') 
plt.legend(loc="lower right") 
plt.show()

补充知识:批量进行One-hot-encoder且进行特征字段拼接,并完成模型训练demo

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel
 
val data = (spark.read.format("csv")
 .option("sep", ",")
 .option("inferSchema", "true")
 .option("header", "true")
 .load("/Affairs.csv"))
 
data.createOrReplaceTempView("res1")
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = (spark.sql("select " + affairs +
 "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
 " from res1 "))
 
val categoricals = df.dtypes.filter(_._2 == "StringType") map (_._1)
val indexers = categoricals.map(
 c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)
 
val encoders = categoricals.map(
 c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc").setDropLast(false)
)
 
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.filter(_._2 != "StringType") map (_._1)
val final_colArray = (colArray_numeric ++ colArray_enc).filter(!_.contains("affairs"))
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")
 
/*
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler))
pipeline.fit(df).transform(df)
*/
 
///
// Create an XGBoost Classifier 
val xgb = new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree")).setLabelCol("affairs").setFeaturesCol("features")
 
// XGBoost paramater grid
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10,20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic")) 
  .build())
 
// Create the XGBoost pipeline
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))
 
// Setup the binary classifier evaluator
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC"))
 
// Create the Cross Validation pipeline, using XGBoost as the estimator, the
// Binary Classification evaluator, and xgbParamGrid for hyperparameters
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))
 
 // Create the model by fitting the training data
val xgbModel = cv.fit(df)
 
 // Test the data by scoring the model
val results = xgbModel.transform(df)
 
// Print out a copy of the parameters used by XGBoost, attention pipeline
(xgbModel.bestModel.asInstanceOf[PipelineModel]
 .stages(5).asInstanceOf[XGBoostClassificationModel]
 .extractParamMap().toSeq.foreach(println))
results.select("affairs","prediction").show
 
println("---Confusion Matrix------")
results.stat.crosstab("affairs","prediction").show()
 
// What was the overall accuracy of the model, using AUC
val auc = evaluator.evaluate(results)
println("----AUC--------")
println("auc="+auc)

以上这篇利用scikitlearn画ROC曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程中的for循环语句学习教程
Oct 14 Python
Python+matplotlib+numpy绘制精美的条形统计图
Jan 02 Python
python生成密码字典的方法
Jul 06 Python
python实现AES和RSA加解密的方法
Mar 28 Python
django的auth认证,authenticate和装饰器功能详解
Jul 25 Python
Python如何调用JS文件中的函数
Aug 16 Python
python实现画循环圆
Nov 23 Python
简单了解Python3 bytes和str类型的区别和联系
Dec 19 Python
Python接口自动化判断元素原理解析
Feb 24 Python
python文件排序的方法总结
Sep 13 Python
Python调用高德API实现批量地址转经纬度并写入表格的功能
Jan 12 Python
python使用tkinter实现透明窗体上绘制随机出现的小球(实例代码)
May 17 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
You might like
一个很方便的 XML 类!!原创的噢
2006/10/09 PHP
PHP获取网页标题的3种实现方法代码实例
2014/04/11 PHP
十个PHP高级应用技巧果断收藏
2015/09/25 PHP
PHP 以POST方式提交XML、获取XML,解析XML详解及实例
2016/10/26 PHP
thinkPHP分页功能实例详解
2017/05/05 PHP
ExtJS PropertyGrid中使用Combobox选择值问题
2010/06/13 Javascript
Chrome中JSON.parse的特殊实现
2011/01/12 Javascript
JavaScript自定义方法实现trim()、Ltrim()、Rtrim()的功能
2013/11/03 Javascript
jquery删除提示框弹出是否删除对话框
2014/01/07 Javascript
JavaScript设计模式之观察者模式(发布者-订阅者模式)
2014/09/24 Javascript
jquery操作 iframe的方法
2014/12/03 Javascript
js判断某个方法是否存在实例代码
2015/01/10 Javascript
jQuery实现提交按钮点击后变成正在处理字样并禁止点击的方法
2015/03/24 Javascript
BootStrap 智能表单实战系列(十)自动完成组件的支持
2016/06/13 Javascript
JavaScript编写一个贪吃蛇游戏
2017/03/09 Javascript
微信小程序开发实现的IP地址查询功能示例
2019/03/28 Javascript
Vue实现点击按钮复制文本内容的例子
2019/11/09 Javascript
vue实现tab栏点击高亮效果
2020/08/19 Javascript
[01:30]我们共输赢 完美世界城市挑战赛开启全新赛季
2019/04/19 DOTA
Python中Continue语句的用法的举例详解
2015/05/14 Python
Python遍历目录并批量更换文件名和目录名的方法
2016/09/19 Python
浅谈Python中的作用域规则和闭包
2018/03/20 Python
对numpy的array和python中自带的list之间相互转化详解
2018/04/13 Python
基于python list对象中嵌套元组使用sort时的排序方法
2018/04/18 Python
Python 3.8新特征之asyncio REPL
2019/05/28 Python
浅谈tensorflow中张量的提取值和赋值
2020/01/19 Python
python线性插值解析
2020/07/05 Python
Tensorflow使用Anaconda、pycharm安装记录
2020/07/29 Python
如何利用python之wxpy模块玩转微信
2020/08/17 Python
python 对象真假值的实例(哪些视为False)
2020/12/11 Python
以设计师精品品质提供快速时尚:PopJulia
2018/01/09 全球购物
Merchant 1948澳大利亚:新西兰领先的鞋类和靴子供应商
2018/03/24 全球购物
Dr.Jart+美国官网:韩国药妆品牌
2019/01/18 全球购物
架构师岗位职责
2013/11/18 职场文书
面试复试通知单
2015/04/24 职场文书
详解CSS伪元素的妙用单标签之美
2021/05/25 HTML / CSS