利用scikitlearn画ROC曲线实例


Posted in Python onJuly 02, 2020

一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUC,ROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模块搭建。

具体实现的代码可以参照下面博友的代码,评估svm的分类指标。注意里面的一些细节需要注意,一个是调用roc_curve 方法时,指明目标标签,否则会报错。

具体是这个参数的设置pos_label ,以前在unionbigdata实习时学到的。

重点是以下的代码需要根据实际改写:

mean_tpr = 0.0 
  mean_fpr = np.linspace(0, 1, 100) 
  all_tpr = []
  
  y_target = np.r_[train_y,test_y]
  cv = StratifiedKFold(y_target, n_folds=6)
 
    #画ROC曲线和计算AUC
    fpr, tpr, thresholds = roc_curve(test_y, predict,pos_label = 2)##指定正例标签,pos_label = ###########在数之联的时候学到的,要制定正例
    
    mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
    mean_tpr[0] = 0.0                #初始处为0 
    roc_auc = auc(fpr, tpr) 
    #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
    plt.plot(fpr, tpr, lw=1, label='ROC %s (area = %0.3f)' % (classifier, roc_auc))

然后是博友的参考代码:

# -*- coding: utf-8 -*- 
""" 
Created on Sun Apr 19 08:57:13 2015 
@author: shifeng 
""" 
print(__doc__) 
 
import numpy as np 
from scipy import interp 
import matplotlib.pyplot as plt 
 
from sklearn import svm, datasets 
from sklearn.metrics import roc_curve, auc 
from sklearn.cross_validation import StratifiedKFold 
 
############################################################################### 
# Data IO and generation,导入iris数据,做数据准备 
 
# import some data to play with 
iris = datasets.load_iris() 
X = iris.data 
y = iris.target 
X, y = X[y != 2], y[y != 2]#去掉了label为2,label只能二分,才可以。 
n_samples, n_features = X.shape 
 
# Add noisy features 
random_state = np.random.RandomState(0) 
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] 
 
############################################################################### 
# Classification and ROC analysis 
#分类,做ROC分析 
 
# Run classifier with cross-validation and plot ROC curves 
#使用6折交叉验证,并且画ROC曲线 
cv = StratifiedKFold(y, n_folds=6) 
classifier = svm.SVC(kernel='linear', probability=True, 
           random_state=random_state)#注意这里,probability=True,需要,不然预测的时候会出现异常。另外rbf核效果更好些。 
mean_tpr = 0.0 
mean_fpr = np.linspace(0, 1, 100) 
all_tpr = [] 
 
for i, (train, test) in enumerate(cv): 
  #通过训练数据,使用svm线性核建立模型,并对测试集进行测试,求出预测得分 
  probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) 
#  print set(y[train])           #set([0,1]) 即label有两个类别 
#  print len(X[train]),len(X[test])    #训练集有84个,测试集有16个 
#  print "++",probas_           #predict_proba()函数输出的是测试集在lael各类别上的置信度, 
#  #在哪个类别上的置信度高,则分为哪类 
  # Compute ROC curve and area the curve 
  #通过roc_curve()函数,求出fpr和tpr,以及阈值 
  fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) 
  mean_tpr += interp(mean_fpr, fpr, tpr)     #对mean_tpr在mean_fpr处进行插值,通过scipy包调用interp()函数 
  mean_tpr[0] = 0.0                #初始处为0 
  roc_auc = auc(fpr, tpr) 
  #画图,只需要plt.plot(fpr,tpr),变量roc_auc只是记录auc的值,通过auc()函数能计算出来 
  plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc)) 
 
#画对角线 
plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck') 
 
mean_tpr /= len(cv)           #在mean_fpr100个点,每个点处插值插值多次取平均 
mean_tpr[-1] = 1.0           #坐标最后一个点为(1,1) 
mean_auc = auc(mean_fpr, mean_tpr)   #计算平均AUC值 
#画平均ROC曲线 
#print mean_fpr,len(mean_fpr) 
#print mean_tpr 
plt.plot(mean_fpr, mean_tpr, 'k--', 
     label='Mean ROC (area = %0.2f)' % mean_auc, lw=2) 
 
plt.xlim([-0.05, 1.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('Receiver operating characteristic example') 
plt.legend(loc="lower right") 
plt.show()

补充知识:批量进行One-hot-encoder且进行特征字段拼接,并完成模型训练demo

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel
 
val data = (spark.read.format("csv")
 .option("sep", ",")
 .option("inferSchema", "true")
 .option("header", "true")
 .load("/Affairs.csv"))
 
data.createOrReplaceTempView("res1")
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = (spark.sql("select " + affairs +
 "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
 " from res1 "))
 
val categoricals = df.dtypes.filter(_._2 == "StringType") map (_._1)
val indexers = categoricals.map(
 c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)
 
val encoders = categoricals.map(
 c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc").setDropLast(false)
)
 
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.filter(_._2 != "StringType") map (_._1)
val final_colArray = (colArray_numeric ++ colArray_enc).filter(!_.contains("affairs"))
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")
 
/*
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler))
pipeline.fit(df).transform(df)
*/
 
///
// Create an XGBoost Classifier 
val xgb = new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree")).setLabelCol("affairs").setFeaturesCol("features")
 
// XGBoost paramater grid
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10,20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic")) 
  .build())
 
// Create the XGBoost pipeline
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))
 
// Setup the binary classifier evaluator
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC"))
 
// Create the Cross Validation pipeline, using XGBoost as the estimator, the
// Binary Classification evaluator, and xgbParamGrid for hyperparameters
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))
 
 // Create the model by fitting the training data
val xgbModel = cv.fit(df)
 
 // Test the data by scoring the model
val results = xgbModel.transform(df)
 
// Print out a copy of the parameters used by XGBoost, attention pipeline
(xgbModel.bestModel.asInstanceOf[PipelineModel]
 .stages(5).asInstanceOf[XGBoostClassificationModel]
 .extractParamMap().toSeq.foreach(println))
results.select("affairs","prediction").show
 
println("---Confusion Matrix------")
results.stat.crosstab("affairs","prediction").show()
 
// What was the overall accuracy of the model, using AUC
val auc = evaluator.evaluate(results)
println("----AUC--------")
println("auc="+auc)

以上这篇利用scikitlearn画ROC曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下线程之间的共享和释放示例
May 04 Python
Python中用于返回绝对值的abs()方法
May 14 Python
详解Python编程中基本的数学计算使用
Feb 04 Python
Django实现分页功能
Jul 02 Python
django数据关系一对多、多对多模型、自关联的建立
Jul 24 Python
pandas中遍历dataframe的每一个元素的实现
Oct 23 Python
python不使用for计算两组、多个矩形两两间的iou方式
Jan 18 Python
python下载卫星云图合成gif的方法示例
Feb 18 Python
Python3 assert断言实现原理解析
Mar 02 Python
python实现手势识别的示例(入门)
Apr 15 Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 Python
Python 中的 copy()和deepcopy()
Nov 07 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
You might like
PHP实现的多维数组去重操作示例
2018/07/21 PHP
JS拖动技术 关于setCapture使用
2010/12/09 Javascript
jQuery控制文本框只能输入数字和字母及使用方法
2016/05/26 Javascript
Vuejs第八篇之Vuejs组件的定义实例解析
2016/09/05 Javascript
微信小程序 HTTPS报错整理常见问题及解决方案
2016/12/14 Javascript
基于jQuery制作小图标上下滑动特效
2017/01/18 Javascript
原生JS实现跑马灯效果
2017/02/20 Javascript
nodejs6下使用koa2框架实例
2017/05/18 NodeJs
vue结合axios与后端进行ajax交互的方法
2018/07/06 Javascript
详解关于html,css,js三者的加载顺序问题
2019/04/10 Javascript
VUEX-action可以修改state吗
2019/11/19 Javascript
[01:07:13]TNC vs Pain 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
浅谈Python对内存的使用(深浅拷贝)
2018/01/17 Python
使用Python爬了4400条淘宝商品数据,竟发现了这些“潜规则”
2018/03/23 Python
基于python 处理中文路径的终极解决方法
2018/04/12 Python
Django之模型层多表操作的实现
2019/01/08 Python
解决PyCharm控制台输出乱码的问题
2019/01/16 Python
基于PyQt4和PySide实现输入对话框效果
2019/02/27 Python
Numpy数组array和矩阵matrix转换方法
2019/08/05 Python
Python 获取项目根路径的代码
2019/09/27 Python
python实现的多任务版udp聊天器功能案例
2019/11/13 Python
tensorflow 保存模型和取出中间权重例子
2020/01/24 Python
Python 防止死锁的方法
2020/07/29 Python
python matplotlib库的基本使用
2020/09/23 Python
pymysql模块使用简介与示例
2020/11/17 Python
葡萄牙鞋子品牌:Fair
2016/12/10 全球购物
临床医学大学生求职信
2013/09/28 职场文书
《挑山工》的教学反思
2014/02/16 职场文书
校庆标语集锦
2014/06/25 职场文书
2014年宣传工作总结
2014/11/18 职场文书
3.15消费者权益日活动总结
2015/02/09 职场文书
工伤劳动仲裁代理词
2015/05/25 职场文书
大国崛起英国观后感
2015/06/02 职场文书
Django开发RESTful API实现增删改查(入门级)
2021/05/10 Python
Redis如何使用乐观锁(CAS)保证数据一致性
2022/03/25 Redis
Java异常体系非正常停止和分类
2022/06/14 Java/Android