keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的Numpy入门教程
Apr 26 Python
Python的条件语句与运算符优先级详解
Oct 13 Python
定制FileField中的上传文件名称实例
Aug 23 Python
Python numpy实现二维数组和一维数组拼接的方法
Jun 05 Python
详解TensorFlow查看ckpt中变量的几种方法
Jun 19 Python
对python借助百度云API对评论进行观点抽取的方法详解
Feb 21 Python
python实现键盘输入的实操方法
Jul 16 Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 Python
使用python批量转换文件编码为UTF-8的实现
Apr 03 Python
Mac PyCharm中的.gitignore 安装设置教程
Apr 16 Python
Ubuntu中配置TensorFlow使用环境的方法
Apr 21 Python
Python测试框架:pytest学习笔记
Oct 20 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
关于Iframe如何跨域访问Cookie和Session的解决方法
2013/04/15 PHP
json的键名为数字时的调用方式(示例代码)
2013/11/15 PHP
php使用include 和require引入文件的区别
2017/02/16 PHP
php插件Xajax使用方法详解
2017/08/31 PHP
PHP验证类的封装与使用方法详解
2019/01/10 PHP
为超链接加上disabled后的故事
2010/12/10 Javascript
js对象之JS入门之Array对象操作小结
2011/01/09 Javascript
JavaScript表格常用操作方法汇总
2015/04/15 Javascript
JS对字符串编码的几种方式使用指南
2015/05/14 Javascript
SWFObject基本用法实例分析
2015/07/20 Javascript
JS实现先显示大图后自动收起显示小图的广告代码
2015/09/04 Javascript
jquery获取所有选中的checkbox实现代码
2016/05/26 Javascript
JS原生带小白点轮播图实例讲解
2017/07/22 Javascript
浅谈vue的iview列表table render函数设置DOM属性值的方法
2017/09/30 Javascript
详解一个基于套接字实现长连接的express
2019/03/28 Javascript
axios实现文件上传并获取进度
2020/03/25 Javascript
关于vue路由缓存清除在main.js中的设置
2019/11/06 Javascript
基于Ionic3实现选项卡切换并重新加载echarts
2020/09/24 Javascript
[42:25]2018DOTA2亚洲邀请赛 4.5 淘汰赛 LGD vs Liquid 第三场
2018/04/06 DOTA
python实现unicode转中文及转换默认编码的方法
2017/04/29 Python
Python的语言类型(详解)
2017/06/24 Python
用Python和WordCloud绘制词云的实现方法(内附让字体清晰的秘笈)
2019/01/08 Python
使用python实现mqtt的发布和订阅
2019/05/05 Python
浅谈PyQt5 的帮助文档查找方法,可以查看每个类的方法
2019/06/25 Python
Python实现微信小程序支付功能
2019/07/25 Python
解析python的局部变量和全局变量
2019/08/15 Python
Html5页面获取微信公众号的openid的方法
2020/05/12 HTML / CSS
马来西亚网上美容店:Hermo.my
2017/11/25 全球购物
ktv总经理岗位职责
2014/02/17 职场文书
2015年财务科工作总结范文
2015/05/13 职场文书
纪律委员竞选稿
2015/11/19 职场文书
2016大学生社会实践单位评语
2015/12/01 职场文书
《司马光》教学反思
2016/02/22 职场文书
初中思想品德教学反思
2016/02/24 职场文书
Java方法重载和方法重写的区别到底在哪?
2021/06/11 Java/Android
详解Redis在SpringBoot工程中的综合应用
2021/10/16 Redis