Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现倒计时的示例
Feb 14 Python
python代码制作configure文件示例
Jul 28 Python
Python常用正则表达式符号浅析
Aug 13 Python
在Python上基于Markov链生成伪随机文本的教程
Apr 17 Python
Python编程中time模块的一些关键用法解析
Jan 19 Python
Django如何实现内容缓存示例详解
Sep 24 Python
Python实现的破解字符串找茬游戏算法示例
Sep 25 Python
Python键盘输入转换为列表的实例
Jun 23 Python
pytorch模型预测结果与ndarray互转方式
Jan 15 Python
python统计函数库scipy.stats的用法解析
Feb 25 Python
ASP.NET Core中的配置详解
Feb 05 Python
python实现自动化群控的步骤
Apr 11 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
windows下升级PHP到5.3.3的过程及注意事项
2010/10/12 PHP
PHP通过iconv将字符串从GBK转换为UTF8字符集
2011/07/18 PHP
如何给phpcms v9增加类似于phpcms 2008中的关键词表
2013/07/01 PHP
js继承 Base类的源码解析
2008/12/30 Javascript
js 事件处理函数间的Event物件是否全等
2011/04/08 Javascript
《JavaScript高级程序设计》阅读笔记(一) ECMAScript基础
2012/02/27 Javascript
JQUERY 设置SELECT选中项代码
2014/02/07 Javascript
Select标签下拉列表二级联动级联实例代码
2014/02/07 Javascript
jQuery移除元素自动解绑事件实现思路及代码
2014/05/31 Javascript
JavaScript学习笔记之内置对象
2015/01/22 Javascript
JavaScript实现上下浮动的窗口效果代码
2015/10/12 Javascript
JS+CSS实现的经典圆角下拉菜单效果代码
2015/10/21 Javascript
jQuery插件EasyUI校验规则 validatebox验证框
2015/11/29 Javascript
剖析Node.js异步编程中的回调与代码设计模式
2016/02/16 Javascript
详解JavaScript中|单竖杠运算符的使用方法
2016/05/23 Javascript
使用BootStrapValidator完成前端输入验证
2016/09/28 Javascript
对称加密与非对称加密优缺点详解
2017/02/06 Javascript
Bootstrap禁用响应式布局的实现方法
2017/03/09 Javascript
vue 实现全选全不选的示例代码
2018/03/29 Javascript
vue cli3适配所有端方案的实现
2020/04/13 Javascript
Python基于回溯法子集树模板解决找零问题示例
2017/09/11 Python
Django查询数据库的性能优化示例代码
2017/09/24 Python
深入理解Django的自定义过滤器
2017/10/17 Python
python使用turtle绘制分形树
2018/06/22 Python
详解Python 解压缩文件
2019/04/09 Python
解决Pycharm后台indexing导致不能run的问题
2019/06/27 Python
对python中不同模块(函数、类、变量)的调用详解
2019/07/16 Python
在django view中给form传入参数的例子
2019/07/19 Python
解决python Jupyter不能导入外部包问题
2020/04/15 Python
python 中的9个实用技巧,助你提高开发效率
2020/08/30 Python
CSS3实现超酷的黑猫警长首页
2016/04/26 HTML / CSS
css3中仿放大镜效果的几种方式原理解析
2020/12/03 HTML / CSS
会计工作心得体会
2014/01/13 职场文书
医药营销个人求职信范文
2014/02/07 职场文书
村委会换届选举方案
2014/05/03 职场文书
关于Nginx中虚拟主机的一些冷门知识小结
2022/03/03 Servers