Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的一个简单LRU cache
Sep 26 Python
仅用500行Python代码实现一个英文解析器的教程
Apr 02 Python
python读写二进制文件的方法
May 09 Python
以一个投票程序的实例来讲解Python的Django框架使用
Feb 18 Python
深入分析python中整型不会溢出问题
Jun 18 Python
详解python while 函数及while和for的区别
Sep 07 Python
Python利用heapq实现一个优先级队列的方法
Feb 03 Python
Python 动态导入对象,importlib.import_module()的使用方法
Aug 28 Python
Python Django 前后端分离 API的方法
Aug 28 Python
python dataframe NaN处理方式
Dec 26 Python
Python-jenkins 获取job构建信息方式
May 12 Python
python如何利用cv2.rectangle()绘制矩形框
Dec 24 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
phpwind中的数据库操作类
2007/01/02 PHP
从PHP $_SERVER相关参数判断是否支持Rewrite模块
2013/09/26 PHP
javascript比较文档位置
2008/04/08 Javascript
JSON 入门指南 想了解json的朋友可以看下
2009/08/26 Javascript
Javascript跨域请求的4种解决方式
2013/03/17 Javascript
jquery插件lazyload.js延迟加载图片的使用方法
2014/02/19 Javascript
javascript中expression的用法整理
2014/05/13 Javascript
jQuery中(function($){})(jQuery)详解
2015/07/15 Javascript
跟我学习javascript的异步脚本加载
2015/11/20 Javascript
jQuery+Ajax+PHP弹出层异步登录效果(附源码下载)
2016/05/27 Javascript
AngularJS基础 ng-non-bindable 指令详细介绍
2016/08/02 Javascript
Javascript基于jQuery UI实现选中区域拖拽效果
2016/11/25 Javascript
基于JavaScript实现全选、不选和反选效果
2017/02/15 Javascript
微信小程序中子页面向父页面传值实例详解
2017/03/20 Javascript
Node.js实现文件上传的示例
2017/06/28 Javascript
简单谈谈js的数据类型
2017/09/25 Javascript
3种vue组件的书写形式
2017/11/29 Javascript
vue自定义移动端touch事件之点击、滑动、长按事件
2018/07/10 Javascript
JS封装的模仿qq右下角消息弹窗功能示例
2018/08/22 Javascript
JavaScript中的事件与异常捕获详析
2019/02/24 Javascript
React路由鉴权的实现方法
2019/09/05 Javascript
使用webpack搭建pixi.js开发环境
2020/02/12 Javascript
vue + el-form 实现的多层循环表单验证
2020/11/25 Vue.js
python3中rank函数的用法
2019/11/27 Python
解决安装新版PyQt5、PyQT5-tool后打不开并Designer.exe提示no Qt platform plugin的问题
2020/04/24 Python
Python xlrd模块导入过程及常用操作
2020/06/10 Python
css3隔行变换色实现示例
2014/02/19 HTML / CSS
html5 figure和figcaption的使用方法
2018/09/10 HTML / CSS
以设计师精品品质提供快速时尚:PopJulia
2018/01/09 全球购物
日本最大美瞳直送网:Morecontact(中文)
2019/04/03 全球购物
老师给学生的表扬信
2014/01/17 职场文书
影视广告专业求职信
2014/09/02 职场文书
迎新生标语大全
2014/10/06 职场文书
2015年推广普通话演讲稿
2015/03/20 职场文书
MySQL和Oracle批量插入SQL的通用写法示例
2021/11/17 MySQL