Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python监控网站运行异常并发送邮件的方法
Mar 13 Python
python动态加载包的方法小结
Apr 18 Python
Java分治归并排序算法实例详解
Dec 12 Python
对numpy和pandas中数组的合并和拆分详解
Apr 11 Python
python:pandas合并csv文件的方法(图书数据集成)
Apr 12 Python
Python使用re模块实现信息筛选的方法
Apr 29 Python
python实现requests发送/上传多个文件的示例
Jun 04 Python
详解django+django-celery+celery的整合实战
Mar 19 Python
Python使用字典实现的简单记事本功能示例
Aug 15 Python
python实现最大优先队列
Aug 29 Python
解决Python数据可视化中文部分显示方块问题
May 16 Python
基于Python实现一个春节倒计时脚本
Jan 22 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
ThinkPHP 连接Oracle数据库的详细教程[全]
2012/07/16 PHP
php上传大文件失败的原因及应对策略
2015/10/20 PHP
解析WordPress中的post_class与get_post_class函数
2016/01/04 PHP
MSN消息提示类
2006/09/05 Javascript
firefox firebug中文入门教程 脚本之家新年特别版
2010/01/02 Javascript
Javascript 入门基础学习
2010/03/10 Javascript
HTML Dom与Css控制方法
2010/10/25 Javascript
js判断浏览器类型为ie6时不执行
2014/06/15 Javascript
在Ubuntu上安装最新版本的Node.js
2014/07/14 Javascript
jQuery中:first选择器用法实例
2014/12/30 Javascript
JS给Textarea文本框添加行号的方法
2015/08/20 Javascript
Nodejs express框架一个工程中同时使用ejs模版和jade模版
2015/12/28 NodeJs
JavaScript实现获取某个元素相邻兄弟节点的prev与next方法
2016/01/25 Javascript
详解js中==与===的区别
2017/01/08 Javascript
vue 解决数组赋值无法渲染在页面的问题
2019/10/28 Javascript
JavaScript实现手机号码 3-4-4格式并控制新增和删除时光标的位置
2020/06/02 Javascript
python中实现定制类的特殊方法总结
2014/09/28 Python
python 根据正则表达式提取指定的内容实例详解
2016/12/04 Python
python逐行读写txt文件的实例讲解
2018/04/03 Python
Python使用分布式锁的代码演示示例
2018/07/30 Python
python检测文件夹变化,并拷贝有更新的文件到对应目录的方法
2018/10/17 Python
Python3.5基础之变量、数据结构、条件和循环语句、break与continue语句实例详解
2019/04/26 Python
python二维码操作:对QRCode和MyQR入门详解
2019/06/24 Python
python开启debug模式的方法
2019/06/27 Python
利用Python实现最小二乘法与梯度下降算法
2021/02/21 Python
Speedo速比涛德国官方网站:世界领先的泳装品牌
2019/08/26 全球购物
大学生军训自我鉴定
2014/02/12 职场文书
铣床操作工岗位职责
2014/06/13 职场文书
学术会议领导致辞
2015/07/29 职场文书
安全教育的主题班会
2015/08/13 职场文书
教师病假条范文
2015/08/17 职场文书
企业廉洁教育心得体会
2016/01/20 职场文书
FFmpeg视频处理入门教程(新手必看)
2022/01/22 杂记
使用Redis实现点赞取消点赞的详细代码
2022/03/20 Redis
Android学习之BottomSheetDialog组件的使用
2022/06/21 Java/Android
GPU服务器的多用户配置方法
2022/07/07 Servers