keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自然语言编码转换模块codecs介绍
Apr 08 Python
利用Python的Django框架生成PDF文件的教程
Jul 22 Python
Python如何快速实现分布式任务
Jul 06 Python
python实现AES和RSA加解密的方法
Mar 28 Python
pandas.cut具体使用总结
Jun 24 Python
Python高级特性 切片 迭代解析
Aug 23 Python
Python Tkinter模块 GUI 可视化实例
Nov 20 Python
Python3.7黑帽编程之病毒篇(基础篇)
Feb 04 Python
Python读入mnist二进制图像文件并显示实例
Apr 24 Python
Python读取Excel一列并计算所有对象出现次数的方法
Sep 04 Python
python四个坐标点对图片区域最小外接矩形进行裁剪
Jun 04 Python
python ansible自动化运维工具执行流程
Jun 24 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
社区(php&&mysql)四
2006/10/09 PHP
动态生成gif格式的图像要注意?
2006/10/09 PHP
学习php过程中的一些注意点的总结
2013/10/25 PHP
PHP 7.0新增加的特性介绍
2017/06/08 PHP
Prototype使用指南之enumerable.js
2007/01/10 Javascript
setTimeout与setInterval在不同浏览器下的差异
2010/01/24 Javascript
javascript+mapbar实现地图定位
2010/04/09 Javascript
JS中判断null、undefined与NaN的方法
2014/03/26 Javascript
图解js图片轮播效果
2015/12/20 Javascript
浅谈在fetch方法中添加header后遇到的预检请求问题
2017/08/31 Javascript
webpack写jquery插件的环境配置
2017/12/21 jQuery
详谈vue+webpack解决css引用图片打包后找不到资源文件的问题
2018/03/06 Javascript
Angular搜索场景中使用rxjs的操作符处理思路
2018/05/30 Javascript
详解如何解决Vue和vue-template-compiler版本之间的问题
2018/09/17 Javascript
webpack结合express实现自动刷新的方法
2019/05/07 Javascript
js实现跟随鼠标移动的小球
2019/08/26 Javascript
微信分享invalid signature签名错误踩过的坑
2020/04/11 Javascript
python使用邻接矩阵构造图代码示例
2017/11/10 Python
python图像处理模块Pillow的学习详解
2019/10/09 Python
解决Jupyter Notebook开始菜单栏Anaconda下消失的问题
2020/04/13 Python
Django models文件模型变更错误解决
2020/05/11 Python
Python plt 利用subplot 实现在一张画布同时画多张图
2021/02/26 Python
CSS3 透明色 RGBA使用介绍
2013/08/06 HTML / CSS
css3实现冲击波效果的示例代码
2018/01/11 HTML / CSS
菲律宾酒店预订网站:Hotels.com菲律宾
2017/07/12 全球购物
Nº21官方在线商店:numeroventuno.com
2019/09/26 全球购物
Linux如何为某个操作添加别名
2013/03/01 面试题
客户代表自我评价范例
2013/09/24 职场文书
销售文员岗位职责
2013/11/29 职场文书
学生打架检讨书
2014/02/14 职场文书
党员岗位承诺书
2014/03/25 职场文书
优秀党员学习焦裕禄精神思想汇报范文
2014/09/10 职场文书
评先进个人材料
2014/12/29 职场文书
演讲开场白和结束语
2015/05/29 职场文书
公司食堂管理制度
2015/08/05 职场文书
go类型转换及与C的类型转换方式
2021/05/05 Golang