keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作mysql数据库
Mar 05 Python
关于Python 3中print函数的换行详解
Aug 08 Python
Python排序算法之选择排序定义与用法示例
Apr 29 Python
Python numpy.array()生成相同元素数组的示例
Nov 12 Python
Python中垃圾回收和del语句详解
Nov 15 Python
Python multiprocess pool模块报错pickling error问题解决方法分析
Mar 20 Python
python画图--输出指定像素点的颜色值方法
Jul 03 Python
django创建超级用户过程解析
Sep 18 Python
Python 批量读取文件中指定字符的实现
Mar 06 Python
python 实现逻辑回归
Dec 30 Python
python爬虫今日热榜数据到txt文件的源码
Feb 23 Python
Python torch.flatten()函数案例详解
Aug 30 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
在WINDOWS中设置计划任务执行PHP文件的方法
2011/12/19 PHP
php中取得文件的后缀名?
2012/02/20 PHP
色色整理的PHP面试题集锦
2012/03/08 PHP
header跳转和include包含问题详解
2012/09/08 PHP
ThinkPHP实现跨模块调用操作方法概述
2014/06/20 PHP
谈谈你对Zend SAPIs(Zend SAPI Internals)的理解
2015/11/10 PHP
yii权限控制的方法(三种方法)
2015/12/28 PHP
php实现获取近几日、月时间示例
2019/07/06 PHP
JavaScript性能陷阱小结(附实例说明)
2010/12/28 Javascript
javascript 闭包
2011/09/15 Javascript
Jquery实现的角色左右选择特效
2014/05/21 Javascript
jQuery级联操作绑定事件实例
2014/09/02 Javascript
js实现选中页面文字将其分享到新浪微博
2015/11/05 Javascript
全面介绍javascript实用技巧及单竖杠
2016/07/18 Javascript
前端面试题及答案整理(二)
2016/08/26 Javascript
Vue2.0组件间数据传递示例
2017/03/07 Javascript
详解Angular之constructor和ngOnInit差异及适用场景
2017/06/22 Javascript
JavaScript学习总结之正则的元字符和一些简单的应用
2017/06/30 Javascript
详解Vue 全局变量,局部变量
2019/04/17 Javascript
如何让Nodejs支持H5 History模式(connect-history-api-fallback源码分析)
2019/05/30 NodeJs
15分钟学会vue项目改造成SSR(小白教程)
2019/12/17 Javascript
解决Vue 给mapState中定义的属性赋值报错的问题
2020/06/22 Javascript
python使用urllib2模块获取gravatar头像实例
2013/12/18 Python
在Python中使用列表生成式的教程
2015/04/27 Python
python 计算数据偏差和峰度的方法
2019/06/29 Python
使用遗传算法求二元函数的最小值
2020/02/11 Python
python实现Pyecharts实现动态地图(Map、Geo)
2020/03/25 Python
python模块如何查看
2020/06/16 Python
Pycharm添加虚拟解释器报错问题解决方案
2020/10/13 Python
海蓝之谜英国官网:La Mer英国
2020/01/15 全球购物
意大利领先的奢侈品在线时装零售商:MCLABELS
2020/10/13 全球购物
新闻专业毕业生英文求职信
2014/03/19 职场文书
公职人员索取回扣检举信
2014/04/04 职场文书
运动会横幅标语
2014/06/17 职场文书
党员教师批评与自我批评发言稿
2014/10/15 职场文书
用Python selenium实现淘宝抢单机器人
2021/06/18 Python