keras中模型训练class_weight,sample_weight区别说明


Posted in Python onMay 23, 2020

keras 中fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0,

validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0,

steps_per_epoch=None, validation_steps=None)

官方文档中:

class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。

sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'。

class_weight---主要针对的上数据不均衡问题,比如:异常检测的二项分类问题,异常数据仅占1%,正常数据占99%; 此时就要设置不同类对loss的影响。

sample_weigh---主要解决的是样本质量不同的问题,比如前1000个样本的可信度,那么它的权重就要高,后1000个样本可能有错、不可信,那么权重就要调低。

补充知识:Keras 中数据不均衡时,metrics,class_weight的设置方法

当数据处理不均衡时,比如处理癌症训练问题,有病样本很少,参考:

http://www.deepideas.net/unbalanced-classes-machine-learning/

主要从两个方面着手:

一、loss函数的权重问题

训练时,设置的权重:

class_weight={
  1: n_non_cancer_samples / n_cancer_samples * t
}

二、编译时设置模型的metrics

def sensitivity(y_true, y_pred):
  true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
  return true_positives / (possible_positives + K.epsilon())

def specificity(y_true, y_pred):
  true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
  possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
  return true_negatives / (possible_negatives + K.epsilon())
model.compile(
  loss='binary_crossentropy',
  optimizer=RMSprop(0.001),
  metrics=[sensitivity, specificity]
)

以上这篇keras中模型训练class_weight,sample_weight区别说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 快速排序代码
Nov 23 Python
python实现在控制台输入密码不显示的方法
Jul 02 Python
基于python的Tkinter编写登陆注册界面
Jun 30 Python
在python3环境下的Django中使用MySQL数据库的实例
Aug 29 Python
Django实现快速分页的方法实例
Oct 22 Python
Python变量赋值的秘密分享
Apr 03 Python
Django组件之cookie与session的使用方法
Jan 10 Python
python的等深分箱实例
Nov 22 Python
TensorFlow加载模型时出错的解决方式
Feb 06 Python
python数据类型可变不可变知识点总结
Mar 06 Python
pyqt5数据库使用详细教程(打包解决方案)
Mar 25 Python
Python OpenCV实现图像模板匹配详解
Apr 07 Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
You might like
php 前一天或后一天的日期
2008/06/28 PHP
ThinkPHP入库出现两次反斜线转义及数据库类转义的解决方法
2014/11/04 PHP
php中http与https跨域共享session的解决方法
2014/12/20 PHP
PHP实现长文章分页实例代码(附源码)
2016/02/03 PHP
常用简易JavaScript函数
2009/04/09 Javascript
基于mootools 1.3框架下的图片滑动效果代码
2011/04/22 Javascript
jQuery 数据缓存模块进化史详细介绍
2012/11/19 Javascript
js获取某月的最后一天日期的简单实例
2013/06/22 Javascript
jquery+php实现搜索框自动提示
2014/11/28 Javascript
jquery easyui validatebox remote的使用详解
2016/11/09 Javascript
jQuery实现弹窗居中效果类似alert()
2017/02/27 Javascript
ES6学习笔记之Set和Map数据结构详解
2017/04/07 Javascript
详解webpack打包vue时提取css
2017/05/26 Javascript
nodejs构建本地web测试服务器 如何解决访问静态资源问题
2017/07/14 NodeJs
iscroll.js滚动加载实例详解
2017/07/18 Javascript
javascript高级模块化require.js的具体使用方法
2017/10/31 Javascript
vue与iframe之间的信息交互的实现
2020/04/08 Javascript
es5 类与es6中class的区别小结
2020/11/09 Javascript
[46:00]Ti4 冒泡赛第二轮LGD vs C9 2
2014/07/14 DOTA
[51:10]VP vs VGJ.S 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
用python代码做configure文件
2014/07/20 Python
仅用500行Python代码实现一个英文解析器的教程
2015/04/02 Python
Python使用Scrapy保存控制台信息到文本解析
2017/12/27 Python
Python设计模式之抽象工厂模式原理与用法详解
2019/01/15 Python
python实现统计文本中单词出现的频率详解
2019/05/20 Python
详解使用Python下载文件的几种方法
2019/10/13 Python
python区分不同数据类型的方法
2019/10/14 Python
python为Django项目上的每个应用程序创建不同的自定义404页面(最佳答案)
2020/03/09 Python
德国电子商城:ComputerUniverse
2017/04/21 全球购物
捷克鲜花配送:Florea.cz
2018/10/29 全球购物
美国专业消费电子及摄影器材网站:B&H Photo Video
2019/12/18 全球购物
奥巴马竞选演讲稿
2014/05/15 职场文书
李敖北大演讲稿
2014/05/24 职场文书
2015年车间安全管理工作总结
2015/05/13 职场文书
《家庭教育》读后感3篇
2019/12/18 职场文书
详解CSS故障艺术
2021/05/25 HTML / CSS