Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之eval()函数危险性浅析
Jul 03 Python
详解Python中的array数组模块相关使用
Jul 05 Python
Python加密方法小结【md5,base64,sha1】
Jul 13 Python
Python实现动态加载模块、类、函数的方法分析
Jul 18 Python
Python中实现switch功能实例解析
Jan 11 Python
Python爬虫之pandas基本安装与使用方法示例
Aug 08 Python
python pandas消除空值和空格以及 Nan数据替换方法
Oct 30 Python
python解析json串与正则匹配对比方法
Dec 20 Python
selenium+python自动化测试之多窗口切换
Jan 23 Python
python3多线程知识点总结
Sep 26 Python
pytorch实现从本地加载 .pth 格式模型
Feb 14 Python
pandas按条件筛选数据的实现
Feb 20 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
PHP MemCached 高级缓存应用代码
2010/08/05 PHP
php的memcache类分享(memcache队列)
2014/03/26 PHP
php实现上传图片生成缩略图示例
2014/04/13 PHP
制作安全性高的PHP网站的几个实用要点
2014/12/30 PHP
PHP实现的简易版图片相似度比较
2015/01/07 PHP
php实现将base64格式图片保存在指定目录的方法
2016/10/13 PHP
动态样式类封装JS代码
2009/09/02 Javascript
ExtJs中简单的登录界面制作方法
2010/08/19 Javascript
jQuery学习笔记(2)--用jquery实现各种模态提示框代码及项目构架
2013/04/08 Javascript
通过url查找a元素并点击
2014/04/09 Javascript
dreamweaver 8实现Jquery自动提示
2014/12/04 Javascript
Vue.js -- 过滤器使用总结
2017/02/18 Javascript
bootstrap轮播图示例代码分享
2017/05/17 Javascript
angularjs路由传值$routeParams详解
2020/09/05 Javascript
实例分析JS与Node.js中的事件循环
2017/12/12 Javascript
angular2中使用第三方js库的实例
2018/02/26 Javascript
vue2.0 循环遍历加载不同图片的方法
2018/03/06 Javascript
node打造微信个人号机器人的方法示例
2018/04/26 Javascript
[02:23]完美世界全国高校联赛街访DOTA2第一期
2019/11/28 DOTA
Python实现抓取百度搜索结果页的网站标题信息
2015/01/22 Python
用Python进行基础的函数式编程的教程
2015/03/31 Python
Python读取Word(.docx)正文信息的方法
2018/03/15 Python
Windows系统下PhantomJS的安装和基本用法
2018/10/21 Python
selenium+python自动化测试环境搭建步骤
2019/06/03 Python
pytorch多进程加速及代码优化方法
2019/08/19 Python
python 实现兔子生兔子示例
2019/11/21 Python
python+opencv3生成一个自定义纯色图教程
2020/02/19 Python
Python文件夹批处理操作代码实例
2020/07/21 Python
Python编写memcached启动脚本代码实例
2020/08/14 Python
详解CSS3 弹性布局快速入门
2019/06/06 HTML / CSS
Europcar葡萄牙:葡萄牙汽车和货车租赁
2017/10/13 全球购物
销售职业生涯规划范文
2014/03/14 职场文书
升旗仪式演讲稿
2014/05/08 职场文书
个人委托书范文
2015/01/28 职场文书
Nginx反向代理至go-fastdfs案例讲解
2021/08/02 Servers
MySQL数据库中的锁、解锁以及删除事务
2022/05/06 MySQL