Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用循环实现批量创建文件夹示例
Mar 25 Python
python中lambda函数 list comprehension 和 zip函数使用指南
Sep 28 Python
Python Numpy 数组的初始化和基本操作
Mar 13 Python
Python zip()函数用法实例分析
Mar 17 Python
Python装饰器用法实例总结
May 26 Python
Python基于OpenCV库Adaboost实现人脸识别功能详解
Aug 25 Python
Python对ElasticSearch获取数据及操作
Apr 24 Python
Python pip替换为阿里源的方法步骤
Jul 02 Python
Python实现大数据收集至excel的思路详解
Jan 03 Python
使用Python合成图片的实现代码(图片添加个性化文本,图片上叠加其他图片)
Apr 30 Python
Pycharm无法打开双击没反应的问题及解决方案
Aug 17 Python
Python爬虫模拟登陆哔哩哔哩(bilibili)并突破点选验证码功能
Dec 21 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
深入PHP许愿墙模块功能分析
2013/06/25 PHP
PHP利用MySQL保存session的实现思路及示例代码
2014/09/09 PHP
浅谈php的TS和NTS的区别
2019/03/13 PHP
PHP 面向对象程序设计之类属性与类常量实现方法分析
2020/04/13 PHP
基于jQuery的倒计时实现代码
2012/05/30 Javascript
Jquery技巧(必须掌握)
2016/03/16 Javascript
Bootstrap+jfinal实现省市级联下拉菜单
2016/05/30 Javascript
JS实现兼容各种浏览器的高级拖动方法完整实例【测试可用】
2016/06/21 Javascript
浅谈JS继承_借用构造函数 & 组合式继承
2016/08/16 Javascript
ES6 javascript的异步操作实例详解
2017/10/30 Javascript
浅谈Vue 数据响应式原理
2018/05/07 Javascript
React项目动态设置title标题的方法示例
2018/09/26 Javascript
Webpack中SplitChunksPlugin 配置参数详解
2020/03/24 Javascript
vue3.0+vue-router+element-plus初实践
2020/12/02 Vue.js
[01:03:37]Secret vs VGJ.S Supermajor小组赛C组 BO3 第二场 6.3
2018/06/04 DOTA
[50:50]完美世界DOTA2联赛PWL S3 Galaxy Racer vs Phoenix 第一场 12.10
2020/12/13 DOTA
python二叉树的实现实例
2013/11/21 Python
Python深入学习之特殊方法与多范式
2014/08/31 Python
Python的Bottle框架中获取制定cookie的教程
2015/04/24 Python
对Python进行数据分析_关于Package的安装问题
2017/05/22 Python
Python读取视频的两种方法(imageio和cv2)
2018/04/15 Python
使用 Python 实现文件递归遍历的三种方式
2018/07/18 Python
python实现大转盘抽奖效果
2019/01/22 Python
python 采用paramiko 远程执行命令及报错解决
2019/10/21 Python
Python 中的pygame安装与配置教程详解
2020/02/10 Python
Python with语句用法原理详解
2020/07/03 Python
css3实现图片遮罩效果鼠标hover以后出现文字
2013/11/05 HTML / CSS
图解CSS3制作圆环形进度条的实例教程
2016/05/26 HTML / CSS
电子商务专业学生的学习自我评价
2013/10/27 职场文书
韩国商务邀请函
2014/01/14 职场文书
《草原》教学反思
2014/02/15 职场文书
励志演讲稿300字
2014/08/21 职场文书
刑事代理授权委托书
2014/09/17 职场文书
小学生优秀评语
2014/12/29 职场文书
酒店保洁员岗位职责
2015/02/26 职场文书
班主任培训研修日志
2015/11/13 职场文书