Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Django中django-userena组件的简单使用教程
May 30 Python
Python使用正则表达式实现文本替换的方法
Apr 18 Python
opencv python 2D直方图的示例代码
Jul 20 Python
Python Numpy 实现交换两行和两列的方法
Jun 26 Python
深入了解Django View(视图系统)
Jul 23 Python
Python生命游戏实现原理及过程解析(附源代码)
Aug 01 Python
python3 requests库实现多图片爬取教程
Dec 18 Python
简单了解python字符串前面加r,u的含义
Dec 26 Python
Tensorflow不支持AVX2指令集的解决方法
Feb 03 Python
pyqt5 QlistView列表显示的实现示例
Mar 24 Python
基于Python的身份证验证识别和数据处理详解
Nov 14 Python
python空元组在all中返回结果详解
Dec 15 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
php5.2.0内存管理改进
2007/01/22 PHP
php页面缓存ob系列函数介绍
2012/10/18 PHP
PHP中常用的转义函数
2014/02/28 PHP
CI框架给视图添加动态数据
2014/12/01 PHP
yii2框架中使用下拉菜单的自动搜索yii-widget-select2实例分析
2016/01/09 PHP
PHP中调用C/C++制作的动态链接库的教程
2016/03/10 PHP
关于PHP中interface的用处详解
2020/07/26 PHP
jquery处理页面弹出层查询数据等待操作实例
2015/03/25 Javascript
JavaScript实现强制重定向至HTTPS页面
2015/06/10 Javascript
Bootstrap每天必学之按钮(Button)插件
2016/04/25 Javascript
利用Angular+Angular-Ui实现分页(代码加简单)
2017/03/10 Javascript
Vue.js上传图片到阿里云OSS存储的方法示例
2018/12/13 Javascript
Antd的Table组件嵌套Table以及选择框联动操作
2020/10/24 Javascript
[03:57]2016完美“圣”典风云人物:rOtk专访
2016/12/09 DOTA
Python实现一个简单的MySQL类
2015/01/07 Python
Python环境下搭建属于自己的pip源的教程
2016/05/05 Python
Python有序字典简单实现方法示例
2017/09/28 Python
Python实现的括号匹配判断功能示例
2018/08/25 Python
django自定义模板标签过程解析
2019/12/14 Python
tensorflow 获取所有variable或tensor的name示例
2020/01/04 Python
Python基础之列表常见操作经典实例详解
2020/02/26 Python
Python使用扩展库pywin32实现批量文档打印实例
2020/04/09 Python
Pytorch 使用CNN图像分类的实现
2020/06/16 Python
使用OpenCV对车道进行实时检测的实现示例代码
2020/06/19 Python
美国知名奢侈美容品牌零售商:Cos Bar
2017/04/21 全球购物
Vuori官网:运动服装的终级表现
2021/01/27 全球购物
PHP如何自定义函数
2016/09/16 面试题
金额转换,阿拉伯数字的金额转换成中国传统的形式如:(¥1011)-> (一千零一拾一元整)输出
2015/05/29 面试题
税务专业毕业生自荐信
2013/11/10 职场文书
亲子读书活动方案
2014/02/22 职场文书
效能监察建议书
2014/05/19 职场文书
党支部组织生活会整改方案
2014/09/30 职场文书
研究生毕业论文导师评语
2014/12/31 职场文书
《赵州桥》教学反思
2016/02/17 职场文书
《所见》教学反思
2016/02/23 职场文书
JavaScript架构localStorage特殊场景下二次封装操作
2022/06/21 Javascript