Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python制作爬虫爬取京东商品评论教程
Dec 16 Python
spyder常用快捷键(分享)
Jul 19 Python
Python实现文件内容批量追加的方法示例
Aug 29 Python
Python简单计算文件MD5值的方法示例
Apr 11 Python
浅谈pyqt5在QMainWindow中布局的问题
Jun 21 Python
Django中间件基础用法详解
Jul 18 Python
基于python二叉树的构造和打印例子
Aug 09 Python
Django为窗体加上防机器人的验证码功能过程解析
Aug 14 Python
通过Turtle库在Python中绘制一个鼠年福鼠
Feb 03 Python
Python小整数对象池和字符串intern实例解析
Mar 21 Python
如何基于python对接钉钉并获取access_token
Apr 21 Python
python xlwt模块的使用解析
Apr 13 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
用PHP实现小型站点广告管理(修正版)
2006/10/09 PHP
Php做的端口嗅探器--可以指定网站和端口
2006/10/09 PHP
Laravel配合jwt使用的方法实例
2020/10/25 PHP
仿猪八戒网左下角的文字滚动效果
2011/10/28 Javascript
jQuery给动态添加的元素绑定事件的方法
2015/03/09 Javascript
AngularJs  Creating Services详解及示例代码
2016/09/02 Javascript
微信小程序开发之选项卡(窗口底部TabBar)页面切换
2017/04/12 Javascript
JavaScript日期工具类DateUtils定义与用法示例
2018/09/03 Javascript
jQuery实现增删改查
2020/12/22 jQuery
Python中无限元素列表的实现方法
2014/08/18 Python
粗略分析Python中的内存泄漏
2015/04/23 Python
详解python实现线程安全的单例模式
2018/03/05 Python
深入浅析Python2.x和3.x版本的主要区别
2018/11/30 Python
对Python Pexpect 模块的使用说明详解
2019/02/14 Python
Python pandas.DataFrame 找出有空值的行
2019/09/09 Python
ansible动态Inventory主机清单配置遇到的坑
2020/01/19 Python
python3注册全局热键的实现
2020/03/22 Python
Django DRF APIView源码运行流程详解
2020/08/17 Python
各大浏览器 CSS3 和 HTML5 兼容速查表 图文
2010/04/01 HTML / CSS
HTML5中的新元素介绍
2008/10/17 HTML / CSS
英国领先的票务代理商之一:The Ticket Factory
2019/02/09 全球购物
构造器Constructor是否可被override?
2013/08/06 面试题
海南地接欢迎词
2014/01/14 职场文书
销售业务员岗位职责
2014/01/29 职场文书
第一批党的群众路线教育实践活动工作总结
2014/03/03 职场文书
保险公司早会主持词
2014/03/22 职场文书
文明寝室申报材料
2014/05/12 职场文书
出国签证在职证明范本
2014/11/24 职场文书
党员先进事迹材料
2014/12/19 职场文书
归元寺导游词
2015/02/06 职场文书
物业项目经理岗位职责
2015/04/01 职场文书
书法社团活动总结
2015/05/07 职场文书
如何写观后感
2015/06/19 职场文书
2016党员入党决心书
2015/09/22 职场文书
社区志愿者服务心得体会
2016/01/22 职场文书
Python基本数据类型之字符串str
2021/07/21 Python