Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
可用于监控 mysql Master Slave 状态的python代码
Feb 10 Python
python安装mysql-python简明笔记(ubuntu环境)
Jun 25 Python
pycharm中连接mysql数据库的步骤详解
May 02 Python
python pandas中DataFrame类型数据操作函数的方法
Apr 08 Python
Python学习笔记之视频人脸检测识别实例教程
Mar 06 Python
python小程序实现刷票功能详解
Jul 17 Python
Python实现的爬取豆瓣电影信息功能案例
Sep 15 Python
python查看数据类型的方法
Oct 12 Python
python pptx复制指定页的ppt教程
Feb 14 Python
keras的siamese(孪生网络)实现案例
Jun 12 Python
python绘制趋势图的示例
Sep 17 Python
接口自动化多层嵌套json数据处理代码实例
Nov 20 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
检测png图片是否完整的php代码
2010/09/06 PHP
一个显示某段时间内每个月的方法 返回由这些月份组成的数组
2012/05/16 PHP
解决PHP4.0 和 PHP5.0类构造函数的兼容问题
2013/08/01 PHP
php如何实现只替换一次或N次
2015/10/29 PHP
页面装载js及性能分析方法介绍
2014/03/21 Javascript
全面兼容的javascript时间格式化函数(比较实用)
2014/05/14 Javascript
JavaScript继承基础讲解(原型链、借用构造函数、混合模式、原型式继承、寄生式继承、寄生组合式继承)
2014/08/16 Javascript
基于jQuery实现复选框的全选 全不选 反选功能
2014/11/24 Javascript
深入分析javascript中的错误处理机制
2016/07/17 Javascript
BootStrap 下拉菜单点击之后不会出现下拉菜单(下拉菜单不弹出)的解决方案
2016/12/14 Javascript
浅谈javascript中的事件冒泡和事件捕获
2016/12/28 Javascript
react开发教程之React 组件之间的通信方式
2017/08/12 Javascript
IntersectionObserver实现图片懒加载的示例
2017/09/29 Javascript
js判断文件类型大小并给出提示的实现方法
2018/01/03 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
js计时事件实现圆形时钟
2020/03/25 Javascript
[54:47]Liquid vs VP Supermajor决赛 BO 第五场 6.10
2018/07/05 DOTA
用python标准库difflib比较两份文件的异同详解
2018/11/16 Python
Python对接 xray 和微信实现自动告警
2019/09/17 Python
django2.2 和 PyMySQL版本兼容问题
2020/02/17 Python
Python argparse模块使用方法解析
2020/02/20 Python
Python 去除字符串中指定字符串
2020/03/05 Python
Python字符串hashlib加密模块使用案例
2020/03/10 Python
python uuid生成唯一id或str的最简单案例
2021/01/13 Python
一款恶搞头像特效的制作过程 利用css3和jquery
2014/11/21 HTML / CSS
荷兰网上鞋店:Ziengs.nl
2017/01/02 全球购物
新加坡最佳婴儿用品店:Mamahood.com.sg
2018/08/26 全球购物
用C#语言写出与SQLSERVER访问时的具体过程
2013/04/16 面试题
建筑工程实习自我鉴定
2013/09/19 职场文书
中医药大学毕业生自荐信
2013/11/08 职场文书
物流仓管员岗位职责
2013/12/04 职场文书
班队活动设计方案
2014/01/30 职场文书
见习报告格式要求
2014/11/04 职场文书
深入理解java.lang.String类的不可变性
2021/06/27 Java/Android
ajax请求前端跨域问题原因及解决方案
2021/10/16 Javascript
PYTHON 使用 Pandas 删除某列指定值所在的行
2022/04/28 Python