Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3.x和Python2.x的区别介绍
Feb 12 Python
使用python编写脚本获取手机当前应用apk的信息
Jul 21 Python
Python判断变量是否已经定义的方法
Aug 18 Python
使用Python进行二进制文件读写的简单方法(推荐)
Sep 12 Python
Python3.7中安装openCV库的方法
Jul 11 Python
python根据文章标题内容自动生成摘要的实例
Feb 21 Python
Django框架验证码用法实例分析
May 10 Python
Python 中 -m 的典型用法、原理解析与发展演变
Nov 11 Python
Pytorch 使用不同版本的cuda的方法步骤
Apr 02 Python
Python一行代码实现自动发邮件功能
May 30 Python
常用的Python代码调试工具总结
Jun 23 Python
Pytest中skip和skipif的具体使用方法
Jun 30 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
全国FM电台频率大全 - 1 北京市
2020/03/11 无线电
php根据日期判断星座的函数分享
2014/02/13 PHP
PHP超全局数组(Superglobals)介绍
2015/07/01 PHP
thinkPHP框架实现的简单计算器示例
2018/12/07 PHP
PHP实现数组根据某个字段进行水平合并,横向合并案例分析
2019/10/08 PHP
解javascript 混淆加密收藏
2009/01/16 Javascript
node.js中的fs.appendFile方法使用说明
2014/12/17 Javascript
jQuery中ajax的post()方法用法实例
2014/12/26 Javascript
根据配置文件加载js依赖模块
2014/12/29 Javascript
js读取csv文件并使用json显示出来
2015/01/09 Javascript
JavaScript中的fontsize()方法使用详解
2015/06/08 Javascript
jQuery实现简单的点赞效果
2020/05/29 Javascript
浅谈移动端之js touch事件 手势滑动事件
2016/11/07 Javascript
判断横屏竖屏(三种)
2017/02/13 Javascript
细说webpack源码之compile流程-rules参数处理技巧(1)
2017/12/26 Javascript
vue项目实现表单登录页保存账号和密码到cookie功能
2018/08/31 Javascript
node删除、复制文件或文件夹示例代码
2019/08/13 Javascript
提升Python程序运行效率的6个方法
2015/03/31 Python
编写Python脚本来获取mp3文件tag信息的教程
2015/05/04 Python
python类继承用法实例分析
2015/05/27 Python
Python3爬虫之自动查询天气并实现语音播报
2019/02/21 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
2019/06/24 Python
Python同时迭代多个序列的方法
2020/07/28 Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
2020/11/18 Python
python实现简单的井字棋游戏(gui界面)
2021/01/22 Python
降消项目实施方案
2014/03/30 职场文书
电话客服工作职责
2014/07/27 职场文书
学校创先争优活动总结
2014/08/28 职场文书
房屋租赁合同补充协议
2014/10/11 职场文书
工程部岗位职责
2015/02/10 职场文书
热血教师观后感
2015/06/10 职场文书
详解CocosCreator消息分发机制
2021/04/16 Javascript
源码解读Spring-Integration执行过程
2021/06/11 Java/Android
php去除数组中为0的元素的实例分析
2021/11/17 PHP
在vue中import()语法不能传入变量的问题及解决
2022/04/01 Vue.js
Python 装饰器(decorator)常用的创建方式及解析
2022/04/24 Python