Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现将n个点均匀地分布在球面上的方法
Mar 12 Python
Python的迭代器和生成器
Jul 29 Python
python中异常捕获方法详解
Mar 03 Python
Python 由字符串函数名得到对应的函数(实例讲解)
Aug 10 Python
Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法
Sep 08 Python
python flask实现分页的示例代码
Aug 02 Python
Python解释器及PyCharm工具安装过程
Feb 26 Python
Python3 pickle对象串行化代码实例解析
Mar 23 Python
Anconda环境下Vscode安装Python的方法详解
Mar 29 Python
Python实现简单的猜单词小游戏
Oct 28 Python
python try...finally...的实现方法
Nov 25 Python
Python关于拓扑排序知识点讲解
Jan 04 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
ubuntu 编译安装php 5.3.3+memcache的方法
2010/08/05 PHP
PHP实现的简单mock json脚本分享
2015/02/10 PHP
php实现的http请求封装示例
2016/11/08 PHP
PHP中SQL查询语句的id=%d解释(推荐)
2016/12/10 PHP
jquery中常用的SET和GET$(”#msg”).html循环介绍
2013/10/09 Javascript
JavaScript对象之深度克隆介绍
2014/12/08 Javascript
jquery实现向下滑出的二级导航下滑菜单效果
2015/08/25 Javascript
JS创建对象的写法示例
2016/11/04 Javascript
利用JS实现简单的日期选择插件
2017/01/23 Javascript
浅谈事件冒泡、事件委托、jQuery元素节点操作、滚轮事件与函数节流
2017/07/22 jQuery
jQuery操作attr、prop、val()/text()/html()、class属性
2019/05/23 jQuery
Layui表格监听行单双击事件讲解
2019/11/14 Javascript
python实现监控windows服务并自动启动服务示例
2014/04/17 Python
python根据开头和结尾字符串获取中间字符串的方法
2015/03/26 Python
解决Python中由于logging模块误用导致的内存泄露
2015/04/23 Python
python3实现短网址和数字相互转换的方法
2015/04/28 Python
Python二分法搜索算法实例分析
2015/05/11 Python
简介Python中用于处理字符串的center()方法
2015/05/18 Python
Python导入oracle数据的方法
2015/07/10 Python
深入讲解Python编程中的字符串
2015/10/14 Python
python 使用shutil复制图片的例子
2019/12/13 Python
Win下PyInstaller 安装和使用教程
2019/12/25 Python
浅谈css3中calc在less编译时被计算的解决办法
2017/12/04 HTML / CSS
杰夫·班克斯男士服装网上商店:Jeff Banks
2019/10/24 全球购物
通信工程专业个人找工作求职信范文
2013/09/21 职场文书
物流专业大学生的自我鉴定
2013/11/13 职场文书
会计出纳员的自我评价
2014/01/15 职场文书
学生会招新策划书
2014/02/14 职场文书
《将心比心》教学反思
2014/04/08 职场文书
终止劳动合同协议书
2014/04/14 职场文书
离职证明标准格式
2014/09/15 职场文书
有限责任公司股东合作协议书范本
2014/10/30 职场文书
先进教师事迹材料
2014/12/16 职场文书
培训班通知
2015/04/25 职场文书
党员干部学习三严三实心得体会
2016/01/05 职场文书
《七律·长征》教学反思
2016/02/16 职场文书