keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本设置超时机制系统时间的方法
Feb 21 Python
Python如何实现文本转语音
Aug 08 Python
python妙用之编码的转换详解
Apr 21 Python
python的mysqldb安装步骤详解
Aug 14 Python
Python实现小数转化为百分数的格式化输出方法示例
Sep 20 Python
Python爬虫实现百度图片自动下载
Feb 04 Python
Python Django 命名空间模式的实现
Aug 09 Python
基于Python2、Python3中reload()的不同用法介绍
Aug 12 Python
用Python批量把文件复制到另一个文件夹的实现方法
Aug 16 Python
python中删除某个元素的方法解析
Nov 05 Python
virtualenv介绍及简明教程
Jun 23 Python
windows支持哪个版本的python
Jul 03 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
简单实用的网站PHP缓存类实例
2014/07/18 PHP
浅析PHP文件下载原理
2014/12/25 PHP
微信支付开发发货通知实例
2016/07/12 PHP
PHP利用Socket获取网站的SSL证书与公钥
2017/06/18 PHP
js select常用操作控制代码
2010/03/16 Javascript
浅析jquery的作用与优势
2013/12/02 Javascript
jQuery中clearQueue()方法用法实例
2014/12/29 Javascript
AngularJS进行性能调优的7个建议
2015/12/28 Javascript
JS 面向对象之继承---多种组合继承详解
2016/07/10 Javascript
JavaScript简单获取系统当前时间完整示例
2016/08/02 Javascript
微信小程序  Mustache语法详细介绍
2016/10/27 Javascript
vue2.0嵌套路由实现豆瓣电影分页功能(附demo)
2017/03/13 Javascript
javascript 取小数点后几位几种方法总结
2017/08/02 Javascript
Nodejs中crypto模块的安全知识讲解
2018/01/03 NodeJs
p5.js实现斐波那契螺旋的示例代码
2018/03/22 Javascript
pm2发布node配置文件ecosystem.json详解
2019/05/15 Javascript
three.js 制作动态二维码的示例代码
2020/07/31 Javascript
[37:29]完美世界DOTA2联赛PWL S2 LBZS vs Forest 第二场 11.19
2020/11/19 DOTA
将string类型的数据类型转换为spark rdd时报错的解决方法
2019/02/18 Python
在macOS上搭建python环境的实现方法
2019/08/13 Python
使用 Python 清理收藏夹里已失效的网站
2019/12/03 Python
Python datetime 如何处理时区信息
2020/09/02 Python
python与js主要区别点总结
2020/09/13 Python
scrapy redis配置文件setting参数详解
2020/11/18 Python
伦敦最受欢迎的蛋糕店:Konditor & Cook
2019/11/01 全球购物
家乐福台湾线上购物网:Carrefour台湾
2020/09/15 全球购物
卫校护理专业毕业生求职信
2013/11/26 职场文书
物理教学随笔感言
2014/02/22 职场文书
2014学雷锋活动总结
2014/03/09 职场文书
乡镇消防工作实施方案
2014/03/27 职场文书
班主任与学生安全责任书
2014/07/25 职场文书
学习党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
九寨沟导游词
2015/02/02 职场文书
2015年安全教育月活动总结
2015/03/26 职场文书
运动会跳远广播稿
2015/08/19 职场文书
python geopandas读取、创建shapefile文件的方法
2021/06/29 Python