keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python升级提示Tkinter模块找不到的解决方法
Aug 22 Python
详解Python操作RabbitMQ服务器消息队列的远程结果返回
Jun 30 Python
python中numpy基础学习及进行数组和矢量计算
Feb 12 Python
python中模块的__all__属性详解
Oct 26 Python
python机器学习理论与实战(四)逻辑回归
Jan 19 Python
Python装饰器原理与简单用法实例分析
Apr 29 Python
python进行TCP端口扫描的实现
Dec 21 Python
python面向对象 反射原理解析
Aug 12 Python
Python使用Excel将数据写入多个sheet
May 16 Python
python程序实现BTC(比特币)挖矿的完整代码
Jan 20 Python
PyQt 如何创建自定义QWidget
Mar 24 Python
粗暴解决CUDA out of memory的问题
May 22 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
带密匙的php加密解密示例分享
2014/01/29 PHP
Linux下安装PHP MSSQL扩展教程
2014/10/24 PHP
PHP模板引擎Smarty之配置文件在模板变量中的使用方法示例
2016/04/11 PHP
php中钩子(hook)的原理与简单应用demo示例
2019/09/03 PHP
javascript 中对象的继承〔转贴〕
2007/01/22 Javascript
JavaScript 变量命名规则
2009/09/23 Javascript
jQuery学习笔记 操作jQuery对象 属性处理
2012/09/19 Javascript
Jquery动态更改一张位图的src与Attr的使用
2013/07/31 Javascript
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
JavaScript获得当前网页来源页面(即上一页)的方法
2015/04/03 Javascript
javascript中callee与caller的区别分析
2015/04/20 Javascript
jQuery带进度条全屏图片轮播特效代码分享
2020/06/28 Javascript
Angular实现form自动布局
2016/01/28 Javascript
纯js实现html转pdf的简单实例(推荐)
2017/02/16 Javascript
JS中Swiper的使用和轮播图效果
2017/08/11 Javascript
微信小程序出现wx.navigateTo页面不跳转问题的解决方法
2017/12/26 Javascript
JavaScript去掉数组重复项的方法分析【测试可用】
2018/07/19 Javascript
koa2的中间件功能及应用示例
2020/03/05 Javascript
javascript局部自定义鼠标右键菜单
2020/12/08 Javascript
[38:54]完美世界DOTA2联赛PWL S2 Rebirth vs LBZS 第一场 11.28
2020/12/01 DOTA
python 域名分析工具实现代码
2009/07/15 Python
Using Django with GAE Python 后台抓取多个网站的页面全文
2016/02/17 Python
Python实现删除文件中含“指定内容”的行示例
2017/06/09 Python
Python使用微信接入图灵机器人过程解析
2019/11/04 Python
Python编译成.so文件进行加密后调用的实现
2019/12/23 Python
python图形界面开发之wxPython树控件使用方法详解
2020/02/24 Python
python 利用jieba.analyse进行 关键词提取
2020/12/17 Python
迪奥美国官网:Dior美国
2019/12/07 全球购物
eDreams德国:南欧领先的在线旅游公司
2020/12/07 全球购物
博士生入学考试推荐信
2013/11/17 职场文书
报到证办理个人委托书
2014/10/06 职场文书
初中英语教师个人工作总结
2015/02/09 职场文书
《活见鬼》教学反思
2016/02/24 职场文书
毕业生就业推荐表自我鉴定
2019/06/20 职场文书
利用ajax+php实现商品价格计算
2021/03/31 PHP
Django rest framework如何自定义用户表
2021/06/09 Python