Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫辅助利器PyQuery模块的安装使用攻略
Apr 24 Python
解决Python print 输出文本显示 gbk 编码错误问题
Jul 13 Python
Python数据结构之栈、队列及二叉树定义与用法浅析
Dec 27 Python
selenium+python自动化测试之页面元素定位
Jan 23 Python
python flask解析json数据不完整的解决方法
May 26 Python
Pycharm保存不能自动同步到远程服务器的解决方法
Jun 27 Python
Python3 解决读取中文文件txt编码的问题
Dec 20 Python
解决springboot yml配置 logging.level 报错问题
Feb 21 Python
python百行代码自制电脑端网速悬浮窗的实现
May 12 Python
Elasticsearch py客户端库安装及使用方法解析
Sep 14 Python
Python timeit模块原理及使用方法
Oct 10 Python
python mock测试的示例
Oct 19 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
discuz Passport 通行证 整合笔记
2008/06/30 PHP
php中自定义函数dump查看数组信息类似var_dump
2014/01/27 PHP
利用JQuery为搜索栏增加tag提示
2009/06/22 Javascript
无阻塞加载脚本分析[全]
2011/01/20 Javascript
33个优秀的jQuery 教程分享(幻灯片、动画菜单)
2011/07/08 Javascript
页面只能打开一次Cooike如何实现
2012/12/04 Javascript
JS刷新框架外页面七种实现代码
2013/02/18 Javascript
调试代码导致IE出错的避免方法
2014/04/04 Javascript
JS实现鼠标经过好友列表中的好友头像时显示资料卡的效果
2014/07/02 Javascript
jQuery实现点击该行即可删除HTML表格行
2014/10/17 Javascript
js动态切换图片的方法
2015/01/20 Javascript
js字符串操作方法实例分析
2015/05/06 Javascript
javascript实现五星评价代码(源码下载)
2015/08/11 Javascript
Bootstrap编写导航栏和登陆框
2016/05/30 Javascript
jQuery实现对无序列表的排序功能(附demo源码下载)
2016/06/25 Javascript
ES6学习笔记之Set和Map数据结构详解
2017/04/07 Javascript
如何优雅的在一台vps(云主机)上面部署vue+mongodb+express项目
2019/01/20 Javascript
node.js基于socket.io快速实现一个实时通讯应用
2019/04/23 Javascript
[02:51]DOTA2英雄基础教程 艾欧
2014/01/13 DOTA
Python实现的简单文件传输服务器和客户端
2015/04/08 Python
详解Python中的Descriptor描述符类
2016/06/14 Python
python实现最长公共子序列
2018/05/22 Python
python实现可视化动态CPU性能监控
2018/06/21 Python
Python函数参数操作详解
2018/08/03 Python
numpy 计算两个数组重复程度的方法
2018/11/07 Python
50行Python代码实现视频中物体颜色识别和跟踪(必须以红色为例)
2019/11/20 Python
css3实现针线缝合效果(图解步骤)
2013/02/04 HTML / CSS
一款基于css3麻将筛子3D翻转特效的实例教程
2014/12/31 HTML / CSS
Tommy Hilfiger美国官网:美国高端休闲领导品牌
2019/01/14 全球购物
利用异或运算实现两个无符号数的加法运算
2013/12/20 面试题
元旦晚会邀请函
2014/01/27 职场文书
个人反四风对照检查材料思想汇报
2014/09/23 职场文书
缓刑人员思想汇报
2014/10/11 职场文书
聘任书的格式及模板
2019/10/28 职场文书
pytorch显存一直变大的解决方案
2021/04/08 Python
SQL Server中锁的用法
2022/05/20 SQL Server