keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python分割和拼接字符串
Nov 01 Python
跟老齐学Python之不要红头文件(2)
Sep 28 Python
python回溯法实现数组全排列输出实例分析
Mar 17 Python
用Python编写一个简单的俄罗斯方块游戏的教程
Apr 03 Python
python读写ini配置文件方法实例分析
Jun 30 Python
如何准确判断请求是搜索引擎爬虫(蜘蛛)发出的请求
Oct 13 Python
Django添加sitemap的方法示例
Aug 06 Python
python Matplotlib底图中鼠标滑过显示隐藏内容的实例代码
Jul 31 Python
python各类经纬度转换的实例代码
Aug 08 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
PyQt实现计数器的方法示例
Jan 18 Python
Python 中如何使用 virtualenv 管理虚拟环境
Jan 21 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
javascript 文章截取部分无损html显示实现代码
2010/05/04 Javascript
jQuery旋转插件—rotate支持(ie/Firefox/SafariOpera/Chrome)
2013/01/16 Javascript
jquery仿QQ商城带左右按钮控制焦点图片切换滚动效果
2013/06/27 Javascript
js实现无需数据库的县级以上联动行政区域下拉控件
2013/08/14 Javascript
JS控制阿拉伯数字转为中文大写示例代码
2013/09/04 Javascript
使用JavaScript实现Java的List功能(实例讲解)
2013/11/07 Javascript
js实现简单登录功能的实例代码
2013/11/09 Javascript
javascript实现动态侧边栏代码
2014/02/19 Javascript
AngularJS中的Directive自定义一个表格
2016/01/25 Javascript
Bootstrap每天必学之弹出框(Popover)插件
2016/04/25 Javascript
详解AngularJS中的表单验证(推荐)
2016/11/17 Javascript
关于Vue.js 2.0的Vuex 2.0 你需要更新的知识库
2016/11/30 Javascript
详解升级react-router 4 踩坑指南
2017/08/14 Javascript
mockjs,json-server一起搭建前端通用的数据模拟框架教程
2017/12/18 Javascript
记录一篇关于redux-saga的基本使用过程
2018/08/18 Javascript
基于javascript实现移动端轮播图效果
2020/12/21 Javascript
vue如何使用rem适配
2021/02/06 Vue.js
three.js 实现露珠滴落动画效果的示例代码
2021/03/01 Javascript
在Django的session中使用User对象的方法
2015/07/23 Python
浅谈python import引入不同路径下的模块
2017/07/11 Python
浅谈Django REST Framework限速
2017/12/12 Python
对python3 urllib包与http包的使用详解
2018/05/10 Python
python利用pandas将excel文件转换为txt文件的方法
2018/10/23 Python
Python如何使用argparse模块处理命令行参数
2019/12/11 Python
python不同系统中打开方法
2020/06/23 Python
JD Sports芬兰:英国领先的运动鞋和运动服饰零售商
2018/11/16 全球购物
T3官网:头发造型工具
2019/12/26 全球购物
新东网科技Java笔试题
2012/07/13 面试题
考试违纪检讨书
2014/02/02 职场文书
搞笑征婚广告词
2014/03/17 职场文书
乡镇八一建军节活动方案
2014/08/24 职场文书
工作检讨书范文
2015/01/23 职场文书
学校德育工作总结2015
2015/05/11 职场文书
雷锋之歌观后感
2015/06/10 职场文书
校园文化艺术节开幕词
2016/03/04 职场文书
详解使用 CSS prefers-* 规范提升网站的可访问性与健壮性
2021/05/25 HTML / CSS