Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取京东价格分析京东商品价格走势
Jan 09 Python
下载给定网页上图片的方法
Feb 18 Python
python中执行shell的两种方法总结
Jan 10 Python
Python中使用支持向量机(SVM)算法
Dec 26 Python
Python读csv文件去掉一列后再写入新的文件实例
Dec 28 Python
Python获取当前公网ip并自动断开宽带连接实例代码
Jan 12 Python
浅谈Python Opencv中gamma变换的使用详解
Apr 02 Python
python django下载大的csv文件实现方法分析
Jul 19 Python
Anaconda+Pycharm环境下的PyTorch配置方法
Mar 13 Python
anaconda3安装及jupyter环境配置全教程
Aug 24 Python
Python调用JavaScript代码的方法
Oct 27 Python
Django模型层实现多表关系创建和多表操作
Jul 21 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
PHP实现生成带背景的图形验证码功能
2016/10/03 PHP
详谈PHP面向对象中常用的关键字和魔术方法
2017/02/04 PHP
理解JavaScript的prototype属性
2012/02/11 Javascript
JS定时关闭窗口的实例
2013/05/22 Javascript
js闭包的用途详解
2014/11/09 Javascript
javascript查询字符串参数的方法
2015/01/28 Javascript
删除javascript所创建子节点的方法
2015/05/21 Javascript
swtich/if...else的替代语句
2015/08/16 Javascript
javascript动态添加checkbox复选框的方法
2015/12/23 Javascript
jQuery Mobile和HTML5开发App推广注册页
2016/11/07 Javascript
微信小程序的日期选择器的实例详解
2017/09/29 Javascript
微信小程序实现定位及到指定位置导航的示例代码
2019/08/20 Javascript
vue动态加载SVG文件并修改节点数据的操作代码
2020/08/17 Javascript
javascript canvas实现简易时钟例子
2020/09/05 Javascript
[01:16:12]完美世界DOTA2联赛PWL S2 FTD vs Inki 第一场 11.21
2020/11/23 DOTA
django 自定义用户user模型的三种方法
2014/11/18 Python
在Python的Django框架中实现Hacker News的一些功能
2015/04/17 Python
Python 模拟登陆的两种实现方法
2017/08/10 Python
在Python中定义一个常量的方法
2018/11/10 Python
python zip()函数使用方法解析
2019/10/31 Python
Python 切分数组实例解析
2019/11/07 Python
pygame库实现移动底座弹球小游戏
2020/04/14 Python
解决Pytorch训练过程中loss不下降的问题
2020/01/02 Python
浅谈python3打包与拆包在函数的应用详解
2020/05/02 Python
python中re模块知识点总结
2021/01/17 Python
基于canvas的骨骼动画的示例代码
2018/06/12 HTML / CSS
美国排名第一的在线葡萄酒商店:Wine.com
2016/09/07 全球购物
Linux管理员面试经常问道的相关命令
2014/12/12 面试题
竞争上岗演讲稿
2014/01/05 职场文书
绿化先进工作者事迹材料
2014/01/30 职场文书
教师求职信范文
2014/05/24 职场文书
警告通知
2015/04/25 职场文书
班主任工作总结范文
2015/08/13 职场文书
开机音效回归! Windows 11重新引入开机铃声
2021/11/21 数码科技
Nginx动静分离配置实现与说明
2022/04/07 Servers
python 闭包函数详细介绍
2022/04/19 Python