keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python设置Socket代理及实现远程摄像头控制的例子
Nov 13 Python
Python检测网站链接是否已存在
Apr 07 Python
浅谈django开发者模式中的autoreload是如何实现的
Aug 18 Python
Python元组知识点总结
Feb 18 Python
基于python的socket实现单机五子棋到双人对战
Mar 24 Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 Python
python生成器推导式用法简单示例
Oct 08 Python
python、Matlab求定积分的实现
Nov 20 Python
自定义Django默认的sitemap站点地图样式
Mar 04 Python
读取nii或nii.gz文件中的信息即输出图像操作
Jul 01 Python
Python常用数据分析模块原理解析
Jul 20 Python
python+openCV对视频进行截取的实现
Nov 27 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
通过对php一些服务器端特性的配置加强php的安全
2006/10/09 PHP
PHP扩展模块Pecl、Pear以及Perl的区别
2014/04/09 PHP
PHP中isset与array_key_exists的区别实例分析
2015/06/02 PHP
纯php生成随机密码
2015/10/30 PHP
PHP面向对象程序设计中的self、static、parent关键字用法分析
2019/08/14 PHP
jquery 笔记 事件
2011/11/02 Javascript
JavaScript的事件绑定(方便不支持js的时候)
2013/10/01 Javascript
AngularJs根据访问的页面动态加载Controller的解决方案
2015/02/04 Javascript
javascript获取重复次数最多的字符
2015/07/08 Javascript
jquery分隔Url的param方法(推荐)
2016/05/25 Javascript
三种带箭头提示框总结实例
2016/06/14 Javascript
jQuery插件EasyUI获取当前Tab中iframe窗体对象的方法
2016/08/05 Javascript
从零学习node.js之文件操作(三)
2017/02/21 Javascript
elemetUi 组件--el-upload实现上传Excel文件的实例
2017/10/27 Javascript
koa-router路由参数和前端路由的结合详解
2019/05/19 Javascript
js将URL网址转为16进制加密与解密函数
2020/03/04 Javascript
jQuery AJAX应用实例总结
2020/05/19 jQuery
跟老齐学Python之画圈还不简单吗?
2014/09/20 Python
python标准算法实现数组全排列的方法
2015/03/17 Python
Python中的if、else、elif语句用法简明讲解
2016/03/11 Python
Python处理JSON时的值报错及编码报错的两则解决实录
2016/06/26 Python
使用TensorFlow实现二分类的方法示例
2019/02/05 Python
python中property属性的介绍及其应用详解
2019/08/29 Python
python 命令行传入参数实现解析
2019/08/30 Python
python使用opencv实现马赛克效果示例
2019/09/28 Python
Django使用list对单个或者多个字段求values值实例
2020/03/31 Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
2020/06/12 Python
Django项目创建及管理实现流程详解
2020/10/13 Python
汽车检测与维修个人求职信
2013/09/24 职场文书
房地产出纳岗位职责
2013/12/01 职场文书
应聘自荐信
2013/12/14 职场文书
五型班组建设方案
2014/02/10 职场文书
颁奖晚会主持词
2014/03/25 职场文书
教师廉政准则心得体会
2016/01/20 职场文书
《文化苦旅》读后感:阅读,让人诗意地栖居在大地上
2019/12/24 职场文书
不同品牌、不同型号对讲机如何互相通联
2022/02/18 无线电