Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用小技巧总结
Jun 01 Python
python 网络爬虫初级实现代码
Feb 27 Python
Python连接MySQL并使用fetchall()方法过滤特殊字符
Mar 13 Python
Python 详解基本语法_函数_返回值
Jan 22 Python
使用python实现接口的方法
Jul 07 Python
python3.6+django2.0开发一套学员管理系统
Mar 03 Python
Python + selenium自动化环境搭建的完整步骤
May 19 Python
python生成n个元素的全组合方法
Nov 13 Python
python中的subprocess.Popen()使用详解
Dec 25 Python
django在保存图像的同时压缩图像示例代码详解
Feb 11 Python
学习Python列表的基础知识汇总
Mar 10 Python
python用海龟绘图写贪吃蛇游戏
Jun 18 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
B2K与车机的中波PK
2021/03/02 无线电
MySQL的FIND_IN_SET函数使用方法分享
2012/03/27 PHP
Zend Framework常用校验器详解
2016/12/09 PHP
PHP微信公众号开发之微信红包实现方法分析
2017/07/14 PHP
javascript编程起步(第五课)
2007/01/10 Javascript
JavaScript之Getters和Setters 平台支持等详细介绍
2012/12/07 Javascript
jquery.validate的使用说明介绍
2013/11/12 Javascript
IE中图片的onload事件无效问题和解决方法
2014/06/06 Javascript
javascript使用数组的push方法完成快速排序
2014/09/15 Javascript
JS使用ajax方法获取指定url的head信息中指定字段值的方法
2015/03/24 Javascript
JS实现六边形3D拖拽翻转效果的方法
2016/09/11 Javascript
JavaScript的兼容性与调试技巧
2016/11/22 Javascript
bootstrap下拉菜单使用方法解析
2017/01/13 Javascript
js实现下拉菜单效果
2017/03/01 Javascript
jQuery实现的隔行变色功能【案例】
2019/02/18 jQuery
jQuery each和js forEach用法比较
2019/02/27 jQuery
解决cordova+vue 项目打包成APK应用遇到的问题
2019/05/10 Javascript
在JavaScript中使用严格模式(Strict Mode)
2019/06/13 Javascript
Nuxt.js实战和配置详解
2019/08/05 Javascript
小谈angular ng deploy的实现
2020/04/07 Javascript
javascript设计模式 ? 装饰模式原理与应用实例分析
2020/04/14 Javascript
解决父组件将子组件作为弹窗调用只执行一次created的问题
2020/07/24 Javascript
详解设计模式中的工厂方法模式在Python程序中的运用
2016/03/02 Python
Python 实现选择排序的算法步骤
2018/04/22 Python
pip install urllib2不能安装的解决方法
2018/06/12 Python
Python实现的远程文件自动打包并下载功能示例
2019/07/12 Python
python目标检测给图画框,bbox画到图上并保存案例
2020/03/10 Python
python获取linux系统信息的三种方法
2020/10/14 Python
匡威德国官网:Converse德国
2019/01/26 全球购物
QA工程师岗位职责
2013/11/20 职场文书
安全大检查反思材料
2014/01/31 职场文书
白酒营销策划方案
2014/08/17 职场文书
四川省传达学习贯彻党的群众路线教育实践活动总结大会精神新闻稿
2014/10/26 职场文书
2014年酒店工作总结范文
2014/11/17 职场文书
关于CSS浮动与取消浮动的问题
2021/06/28 HTML / CSS
HTML中的表单元素介绍
2022/02/28 HTML / CSS