Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django的URLconf中进行函数导入的方法
Jul 18 Python
Python微信库:itchat的用法详解
Aug 14 Python
Php多进程实现代码
May 07 Python
python队列Queue的详解
May 10 Python
Python3批量生成带logo的二维码方法
Jun 24 Python
利用python开发app实战的方法
Jul 09 Python
Python pandas实现excel工作表合并功能详解
Aug 29 Python
python实现广度优先搜索过程解析
Oct 19 Python
python os模块在系统管理中的应用
Jun 22 Python
Python编写单元测试代码实例
Sep 10 Python
Python自动化测试基础必备知识点总结
Feb 07 Python
详解Python requests模块
Jun 21 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
PHP&MYSQL服务器配置说明
2006/10/09 PHP
正则表达式语法
2006/10/09 Javascript
使用PHP 5.0创建图形的巧妙方法
2010/10/12 PHP
php绘制一个扇形的方法
2015/01/24 PHP
用javascript获取地址栏参数
2006/12/22 Javascript
初学Jquery插件制作 在SageCRM的查询屏幕隐藏部分行的功能
2011/12/26 Javascript
纯js网页画板(Graphics)类简介及实现代码
2012/12/24 Javascript
js函数排序的实例代码
2013/07/01 Javascript
JS实现根据出生年月计算年龄
2014/01/10 Javascript
ie9 提示'console' 未定义问题的解决方法
2014/03/20 Javascript
javascript搜索框效果实现方法
2015/05/14 Javascript
javascript实现3D变换的立体圆圈实例
2015/08/06 Javascript
Node.js返回JSONP详解
2016/05/18 Javascript
JavaScript prototype属性详解
2016/10/25 Javascript
JavaScript中${pageContext.request.contextPath}取值问题及解决方案
2016/12/08 Javascript
基于JavaScript实现无限加载瀑布流
2017/07/21 Javascript
bootstrap日期插件daterangepicker使用详解
2017/10/19 Javascript
PHPStorm中如何对nodejs项目进行单元测试详解
2019/02/28 NodeJs
JS回调函数简单易懂的入门实例分析
2019/09/29 Javascript
[36:41]完美世界DOTA2联赛循环赛FTD vs Magma第一场 10月30日
2020/10/31 DOTA
pandas 使用apply同时处理两列数据的方法
2018/04/20 Python
Python中文件的写入读取以及附加文字方法
2019/01/23 Python
Python实现截取PDF文件中的几页代码实例
2019/03/11 Python
Python-Seaborn热图绘制的实现方法
2019/07/15 Python
使用pip安装python库的多种方式
2019/07/31 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
2020/04/01 Python
python的json包位置及用法总结
2020/06/21 Python
django rest framework 自定义返回方式
2020/07/12 Python
Html5+JS实现手机摇一摇功能
2015/04/24 HTML / CSS
世界上最大的家庭自动化公司:Smarthome
2017/12/20 全球购物
Tirendo比利时:在线购买轮胎
2018/10/22 全球购物
请问软件开发中的设计模式你会使用哪些
2015/05/13 面试题
十佳青年事迹材料
2014/08/21 职场文书
幼儿园八一建军节活动方案
2014/08/27 职场文书
学校联谊协议书
2014/09/16 职场文书
JPA 通过Specification如何实现复杂查询
2021/11/23 Java/Android