Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时采集摄像头图像上传ftp服务器功能实现
Dec 23 Python
Python中的map、reduce和filter浅析
Apr 26 Python
python教程之用py2exe将PY文件转成EXE文件
Jun 12 Python
Python 执行字符串表达式函数(eval exec execfile)
Aug 11 Python
python实现2048小游戏
Mar 30 Python
python基础知识小结之集合
Nov 25 Python
python2与python3中关于对NaN类型数据的判断和转换方法
Oct 30 Python
python随机生成大小写字母数字混合密码(仅20行代码)
Feb 01 Python
将pytorch转成longtensor的简单方法
Feb 18 Python
借助Paramiko通过Python实现linux远程登陆及sftp的操作
Mar 16 Python
Python之qq自动发消息的示例代码
Feb 18 Python
Python获取百度热搜的完整代码
Apr 07 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
php实现删除空目录的方法
2015/03/16 PHP
php实现window平台的checkdnsrr函数
2015/05/27 PHP
PHP抽象类与接口的区别详解
2019/03/21 PHP
PHP连接MySQL数据库的三种方式实例分析【mysql、mysqli、pdo】
2019/11/04 PHP
javascript-TreeView父子联动效果保持节点状态一致
2007/08/12 Javascript
javascript 单选框,多选框美化代码
2008/08/01 Javascript
对于this和$(this)的个人理解
2013/09/08 Javascript
Node.js(安装,启动,测试)
2014/06/09 Javascript
jQuery插件slick实现响应式移动端幻灯片图片切换特效
2015/04/12 Javascript
jQuery自定义插件详解及实例代码
2016/12/29 Javascript
axios基本入门用法教程
2017/03/25 Javascript
微信小程序中做用户登录与登录态维护的实现详解
2017/05/17 Javascript
AngularJS中使用ngModal模态框实例
2017/05/27 Javascript
WebStorm ES6 语法支持设置&babel使用及自动编译(详解)
2017/09/08 Javascript
vue.js前后端数据交互之提交数据操作详解
2018/04/24 Javascript
浅谈vue中get请求解决传输数据是数组格式的问题
2020/08/03 Javascript
[48:26]VGJ.S vs infamous Supermajor 败者组 BO3 第二场 6.4
2018/06/05 DOTA
Flask之请求钩子的实现
2018/12/23 Python
Python子类继承父类构造函数详解
2019/02/19 Python
python里dict变成list实例方法
2019/06/26 Python
python 中pyqt5 树节点点击实现多窗口切换问题
2019/07/04 Python
如何把python项目部署到linux服务器
2020/08/26 Python
HTML中fieldset标签概述及使用方法
2013/02/01 HTML / CSS
英国Office鞋店德国网站:在线购买鞋子、靴子和运动鞋
2018/12/19 全球购物
全球最大运动品牌的男装、女装和童装官方库存商:A&A Sports
2021/01/17 全球购物
财务会计专业推荐信
2013/11/30 职场文书
初中三好学生事迹材料
2014/01/13 职场文书
采购经理岗位职责
2014/02/16 职场文书
工程造价专业大学生职业规划范文
2014/03/09 职场文书
团支书竞选演讲稿
2014/04/28 职场文书
吃空饷专项整治方案
2014/10/27 职场文书
优秀团员个人总结
2015/02/26 职场文书
仓管员岗位职责范本
2015/04/01 职场文书
个人借条范本
2015/05/25 职场文书
旅游安全责任协议书
2016/03/22 职场文书
Python实战实现爬取天气数据并完成可视化分析详解
2022/06/16 Python