Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入解析Python中的变量和赋值运算符
Oct 12 Python
详解用python实现简单的遗传算法
Jan 02 Python
flask中主动抛出异常及统一异常处理代码示例
Jan 18 Python
python的Crypto模块实现AES加密实例代码
Jan 22 Python
Python实现数据可视化看如何监控你的爬虫状态【推荐】
Aug 10 Python
对python xlrd读取datetime类型数据的方法详解
Dec 26 Python
Python可迭代对象操作示例
May 07 Python
numpy求平均值的维度设定的例子
Aug 24 Python
Python pip 安装与使用(安装、更新、删除)
Oct 06 Python
浅谈Python访问MySQL的正确姿势
Jan 07 Python
Python异常继承关系和自定义异常实现代码实例
Feb 20 Python
Python如何配置环境变量详解
May 18 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
PHP开发负载均衡指南
2010/07/17 PHP
打造超酷的PHP数据饼图效果实现代码
2011/11/23 PHP
PHP随机生成随机个数的字母组合示例
2014/01/14 PHP
培养自己的php编码规范
2015/09/28 PHP
php封装json通信接口详解及实例
2017/03/07 PHP
PHP实现的注册,登录及查询用户资料功能API接口示例
2017/06/06 PHP
JavaScript定时器详解及实例
2013/08/01 Javascript
jquery购物车实时结算特效实现思路
2013/09/23 Javascript
JavaScript使用replace函数替换字符串的方法
2015/04/06 Javascript
javascript this详细介绍
2016/09/19 Javascript
NodeJS、NPM安装配置步骤(windows版本) 以及环境变量详解
2017/05/13 NodeJs
使用node.js对音视频文件加密的实例代码
2017/08/30 Javascript
vue-router懒加载速度缓慢问题及解决方法
2018/11/25 Javascript
js事件触发操作实例分析
2019/06/21 Javascript
Electron vue的使用教程图文详解
2019/07/05 Javascript
Vue + element 实现多选框组并保存已选id集合的示例代码
2020/06/03 Javascript
javascript 数组(list)添加/删除的实现
2020/12/17 Javascript
Python单元测试实例详解
2018/05/25 Python
Python将多个list合并为1个list的方法
2018/06/27 Python
python交换两个变量的值方法
2019/01/12 Python
Python网络编程之使用TCP方式传输文件操作示例
2019/11/01 Python
Python list与NumPy array 区分详解
2019/11/06 Python
python_array[0][0]与array[0,0]的区别详解
2020/02/18 Python
Python格式化输出--%s,%d,%f的代码解析
2020/04/29 Python
django中related_name的用法说明
2020/05/20 Python
浅谈TensorFlow之稀疏张量表示
2020/06/30 Python
详解pyqt5的UI中嵌入matplotlib图形并实时刷新(挖坑和填坑)
2020/08/07 Python
CSS3 网页下拉菜单代码解释 中文翻译
2010/02/27 HTML / CSS
Chi Chi London官网:购买连衣裙和礼服
2020/10/25 全球购物
化学学院毕业生自荐信范文
2013/12/17 职场文书
小组合作学习反思
2014/02/18 职场文书
儿童生日会策划方案
2014/05/15 职场文书
学习十八大的心得体会
2014/09/01 职场文书
2014年初中班主任工作总结
2014/11/08 职场文书
结婚保证书(卖身契)
2015/02/26 职场文书
会计稽核岗位职责
2015/04/13 职场文书