Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的彩票机选器实例
Jun 17 Python
Python编程实现控制cmd命令行显示颜色的方法示例
Aug 14 Python
python+VTK环境搭建及第一个简单程序代码
Dec 13 Python
Python实现基于POS算法的区块链
Aug 07 Python
对Python random模块打乱数组顺序的实例讲解
Nov 08 Python
python tkinter控件布局项目实例
Nov 04 Python
Python常用模块os.path之文件及路径操作方法
Dec 03 Python
利用python画出AUC曲线的实例
Feb 28 Python
Python动态强类型解释型语言原理解析
Mar 25 Python
Python实现FTP文件定时自动下载的步骤
Dec 19 Python
python批量生成身份证号到Excel的两种方法实例
Jan 14 Python
python实现马丁策略回测3000只股票的实例代码
Jan 22 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
dede3.1分页文字采集过滤规则详说(图文教程)
2007/04/03 PHP
Cannot modify header information错误解决方法
2008/10/08 PHP
PHP设计模式之装饰者模式
2012/02/29 PHP
与文件上传有关的php配置参数总结
2013/06/14 PHP
最新优化收藏到网摘代码(digg,diigo)
2007/02/07 Javascript
Document 对象的常用方法
2009/07/31 Javascript
《JavaScript高级程序设计》阅读笔记(二) ECMAScript中的原始类型
2012/02/27 Javascript
jquery中get和post的简单实例
2014/02/04 Javascript
jQuery 无限级菜单的简单实例
2014/02/21 Javascript
jquery的ajax异步请求接收返回json数据实例
2014/06/16 Javascript
node.js中使用socket.io的方法
2014/12/15 Javascript
js正则匹配出所有图片及图片地址src的方法
2015/06/08 Javascript
javascript中一些util方法汇总
2015/06/10 Javascript
基于bootstrap3和jquery的分页插件
2015/07/31 Javascript
分享一个原生的JavaScript拖动方法
2016/09/25 Javascript
jQuery中 bind的用法简单介绍
2017/02/13 Javascript
Angular中的interceptors拦截器
2017/06/25 Javascript
vue.js项目打包上线的图文教程
2017/11/16 Javascript
vue中post请求以a=a&b=b 的格式写遇到的问题
2018/04/27 Javascript
jQuery md5加密插件jQuery.md5.js用法示例
2018/08/24 jQuery
Three.js中矩阵和向量的使用教程
2019/03/19 Javascript
微信小程序中使用Async-await方法异步请求变为同步请求方法
2019/03/28 Javascript
layui监听单元格编辑前后交互的例子
2019/09/16 Javascript
Vue项目环境搭建详细总结
2019/09/26 Javascript
使用element-ui +Vue 解决 table 里包含表单验证的问题
2020/07/17 Javascript
js实现幻灯片轮播图
2020/08/14 Javascript
windows下python安装paramiko模块和pycrypto模块(简单三步)
2017/07/06 Python
Python简单基础小程序的实例代码
2019/04/28 Python
浅谈Python中threading join和setDaemon用法及区别说明
2020/05/02 Python
button在IE6/7下的黑边去除方案
2012/12/24 HTML / CSS
html2canvas把div保存图片高清图的方法示例
2018/03/05 HTML / CSS
喜诗官方在线巧克力店:See’s Candies
2017/01/01 全球购物
乌克兰第一的珠宝网上商店:Gold.ua
2019/11/29 全球购物
如何用PHP实现邮件发送
2012/12/26 面试题
安全承诺书
2015/01/19 职场文书
vue实现简单数据双向绑定
2021/04/28 Vue.js