Keras自定义IOU方式


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

def iou(y_true, y_pred, label: int):
  """
  Return the Intersection over Union (IoU) for a given label.
  Args:
    y_true: the expected y values as a one-hot
    y_pred: the predicted y values as a one-hot or softmax output
    label: the label to return the IoU for
  Returns:
    the IoU for the given label
  """
  # extract the label values using the argmax operator then
  # calculate equality of the predictions and truths to the label
  y_true = K.cast(K.equal(K.argmax(y_true), label), K.floatx())
  y_pred = K.cast(K.equal(K.argmax(y_pred), label), K.floatx())
  # calculate the |intersection| (AND) of the labels
  intersection = K.sum(y_true * y_pred)
  # calculate the |union| (OR) of the labels
  union = K.sum(y_true) + K.sum(y_pred) - intersection
  # avoid divide by zero - if the union is zero, return 1
  # otherwise, return the intersection over union
  return K.switch(K.equal(union, 0), 1.0, intersection / union)
 
def mean_iou(y_true, y_pred):
  """
  Return the Intersection over Union (IoU) score.
  Args:
    y_true: the expected y values as a one-hot
    y_pred: the predicted y values as a one-hot or softmax output
  Returns:
    the scalar IoU value (mean over all labels)
  """
  # get number of labels to calculate IoU for
  num_labels = K.int_shape(y_pred)[-1] - 1
  # initialize a variable to store total IoU in
  mean_iou = K.variable(0)
  
  # iterate over labels to calculate IoU for
  for label in range(num_labels):
    mean_iou = mean_iou + iou(y_true, y_pred, label)
    
  # divide total IoU by number of labels to get mean IoU
  return mean_iou / num_labels

补充知识:keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score

keras自定义评估函数

有时候训练模型,现有的评估函数并不足以科学的评估模型的好坏,这时候就需要自定义一些评估函数,比如样本分布不均衡是准确率accuracy评估无法判定一个模型的好坏,这时候需要引入精确度和召回率作为评估标准,不幸的是keras没有这些评估函数。

以下是参考别的文章摘取的两个自定义评估函数

召回率:

def recall(y_true, y_pred):
  true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
  recall = true_positives / (possible_positives + K.epsilon())
  return recall

精确度:

def precision(y_true, y_pred):
  true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
  precision = true_positives / (predicted_positives + K.epsilon())
  return precision

自定义了评估函数,一般在编译模型阶段加入即可:

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy', precision, recall])

自定义了损失函数focal_loss一般也在编译阶段加入:

model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],
metrics=['accuracy',fbeta_score], )

其他的没有特别要注意的点,直接按照原来的思路训练一版模型出来就好了,关键的地方在于加载模型这里,自定义的函数需要特殊的加载方式,不然会出现加载没有自定义函数的问题:ValueError: Unknown loss function:focal_loss

解决方案:

model_name = 'test_calssification_model.h5'
model_dfcw = load_model(model_name,
            custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

注意点:将自定义的损失函数和评估函数都加入到custom_objects里,以上就是在自定义一个损失函数从编译模型阶段到加载模型阶段出现的所有的问题。

以上这篇Keras自定义IOU方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取标准北京时间的方法
Mar 24 Python
Python中运算符"=="和"is"的详解
Oct 08 Python
python3读取MySQL-Front的MYSQL密码
May 03 Python
Python利用Scrapy框架爬取豆瓣电影示例
Jan 17 Python
Pytorch之finetune使用详解
Jan 18 Python
Python chardet库识别编码原理解析
Feb 18 Python
Pycharm中切换pytorch的环境和配置的教程详解
Mar 13 Python
500行python代码实现飞机大战
Apr 24 Python
Python numpy矩阵处理运算工具用法汇总
Jul 13 Python
Python利用Pillow(PIL)库实现验证码图片的全过程
Oct 04 Python
如何使用Python进行PDF图片识别OCR
Jan 22 Python
解决numpy数组互换两行及赋值的问题
Apr 17 Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
keras 获取某层的输入/输出 tensor 尺寸操作
Jun 10 #Python
Python 字典中的所有方法及用法
Jun 10 #Python
在keras 中获取张量 tensor 的维度大小实例
Jun 10 #Python
You might like
php at(@)符号的用法简介
2009/07/11 PHP
几个有用的php字符串过滤,转换函数代码
2012/05/01 PHP
PHP中多线程的两个实现方法
2016/10/14 PHP
js下通过prototype扩展实现indexOf的代码
2010/12/08 Javascript
jquery通过扩展select控件实现支持enter或focus选择的方法
2015/11/19 Javascript
快速学习jQuery插件 jquery.validate.js表单验证插件使用方法
2015/12/01 Javascript
修复jQuery tablesorter无法正确排序的bug(加千分位数字后)
2016/03/30 Javascript
AngularJS基础 ng-switch 指令简单示例
2016/08/03 Javascript
JS使用onerror捕获异常示例
2016/08/03 Javascript
详解vue-cli开发环境跨域问题解决方案
2017/06/06 Javascript
javascript 日期相减-在线教程(附代码)
2017/08/17 Javascript
JS设计模式之命令模式概念与用法分析
2018/02/06 Javascript
JavaScript实现的反序列化json字符串操作示例
2018/07/18 Javascript
详解一个小实例理解js原型和继承
2019/04/24 Javascript
微信小程序表单验证插件WxValidate的二次封装功能(终极版)
2019/09/03 Javascript
js实现数字滚动特效
2019/12/16 Javascript
python用10行代码实现对黄色图片的检测功能
2015/08/10 Python
Python get获取页面cookie代码实例
2018/09/12 Python
Django admin model 汉化显示文字的实现方法
2019/08/12 Python
构建高效的python requests长连接池详解
2020/05/02 Python
Python连接Impala实现步骤解析
2020/08/04 Python
LEGO玩具英国官方商店:LEGO Shop GB
2018/03/27 全球购物
意大利宠物用品购物网站:Bauzaar
2018/09/15 全球购物
Bailey帽子官方商店:Bailey Hats
2018/09/25 全球购物
枚举与#define宏的区别
2014/04/30 面试题
小学三年级数学教学反思
2014/01/31 职场文书
初中毕业生的自我评价
2014/03/03 职场文书
行政专员岗位职责范本
2014/08/26 职场文书
软环境建设心得体会
2014/09/09 职场文书
六一亲子活动感想
2015/08/07 职场文书
2016消防宣传标语口号
2015/12/26 职场文书
如何利用js在两个html窗口间通信
2021/04/27 Javascript
Python面向对象之内置函数相关知识总结
2021/06/24 Python
Nginx 配置 HTTPS的详细过程
2022/05/30 Servers
vue实现在data里引入相对路径
2022/06/05 Vue.js
MySQL示例讲解数据库约束以及表的设计
2022/06/16 MySQL