keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多进程编程技术实例分析
Sep 16 Python
Python验证码识别的方法
Jul 10 Python
Python处理JSON时的值报错及编码报错的两则解决实录
Jun 26 Python
Python实现的手机号归属地相关信息查询功能示例
Jun 08 Python
Python实现决策树C4.5算法的示例
May 30 Python
浅谈python中对于json写入txt文件的编码问题
Jun 07 Python
Python3爬虫学习之MySQL数据库存储爬取的信息详解
Dec 12 Python
python控制nao机器人身体动作实例详解
Apr 29 Python
pywinauto自动化操作记事本
Aug 26 Python
Python3实现配置文件差异对比脚本
Nov 18 Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 Python
python标准库os库的函数介绍
Feb 12 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
详解WordPress开发中get_header()获取头部函数的用法
2016/01/08 PHP
100行PHP代码实现socks5代理服务器
2016/04/28 PHP
vmware linux系统安装最新的php7图解
2019/04/14 PHP
PHP实现本地图片转base64格式并上传
2020/05/29 PHP
JavaScript 空位补零实现代码
2010/02/26 Javascript
javasctipt如何显示几分钟前、几天前等
2014/04/30 Javascript
JS实现超简洁网页title标题跑动闪烁提示效果代码
2015/10/23 Javascript
JS弹出窗口的运用与技巧大全
2016/11/01 Javascript
bootstrap配合Masonry插件实现瀑布式布局
2017/01/18 Javascript
JS ES6中setTimeout函数的执行上下文示例
2017/04/27 Javascript
socket.io学习教程之基础介绍(一)
2017/04/29 Javascript
JS中使用gulp实现压缩文件及浏览器热加载功能
2017/07/12 Javascript
JavaScript选择排序算法原理与实现方法示例
2018/08/06 Javascript
Vue2 监听属性改变watch的实例代码
2018/08/27 Javascript
使用Node.js实现一个多人游戏服务器引擎
2019/03/13 Javascript
JS实现水平遍历和嵌套递归操作示例
2019/08/15 Javascript
JS实现滚动条触底加载更多
2019/09/19 Javascript
解决vue 表格table列求和的问题
2019/11/06 Javascript
微信浏览器下拉黑边解决方案 wScroollFix
2020/01/21 Javascript
JavaScript实现与web通信的方法详解
2020/08/07 Javascript
Vue 使用iframe引用html页面实现vue和html页面方法的调用操作
2020/11/16 Javascript
用Python编写生成树状结构的文件目录的脚本的教程
2015/05/04 Python
python使用matplotlib画饼状图
2018/09/25 Python
python 弹窗提示警告框MessageBox的实例
2019/06/18 Python
Python 控制终端输出文字的实例
2019/07/12 Python
Python彻底删除文件夹及其子文件方式
2019/12/23 Python
python 项目目录结构设置
2020/02/14 Python
Python如何将图像音视频等资源文件隐藏在代码中(小技巧)
2020/02/16 Python
pandas之分组groupby()的使用整理与总结
2020/06/18 Python
amaze ui 的使用详细教程
2020/08/19 HTML / CSS
我看到了用指针调用函数的不同语法形式
2014/07/16 面试题
中医临床专业自我鉴定范文
2014/01/15 职场文书
美德好少年主要事迹
2014/01/29 职场文书
灰雀教学反思
2014/04/28 职场文书
社团活动总结格式
2014/08/29 职场文书
搞笑欢迎词大全
2015/09/30 职场文书