Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过shutil实现快速文件复制的方法
Mar 14 Python
利用QT写一个极简单的图形化Python闹钟程序
Apr 07 Python
python如何通过实例方法名字调用方法
Mar 21 Python
Python实现端口检测的方法
Jul 24 Python
Python常见内置高效率函数用法示例
Jul 31 Python
Pyqt5如何让QMessageBox按钮显示中文示例代码
Apr 11 Python
python3+PyQt5 数据库编程--增删改实例
Jun 17 Python
python打印9宫格、25宫格等奇数格 满足横竖斜相加和相等
Jul 19 Python
多个python文件调用logging模块报错误
Feb 12 Python
python中查看.db文件中表格的名字及表格中的字段操作
Jul 07 Python
python爬虫scrapy框架的梨视频案例解析
Feb 20 Python
python自动化八大定位元素讲解
Jul 09 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
基于Discuz security.inc.php代码的深入分析
2013/06/03 PHP
如何在php中正确的使用json
2013/08/06 PHP
php中heredoc与nowdoc介绍
2014/12/25 PHP
取得一定长度的内容,处理中文
2006/12/20 Javascript
原生JS实现自定义下拉单选选择框功能
2018/10/12 Javascript
优雅的在React项目中使用Redux的方法
2018/11/10 Javascript
微信小程序自定义导航栏
2018/12/31 Javascript
[02:08]我的刀塔不可能这么可爱 胡晓桃_1
2014/06/20 DOTA
python中将字典转换成其json字符串
2014/07/16 Python
Windows上使用virtualenv搭建Python+Flask开发环境
2016/06/07 Python
python爬取NUS-WIDE数据库图片
2016/10/05 Python
Python2和Python3.6环境解决共存问题
2018/11/09 Python
说说如何遍历Python列表的方法示例
2019/02/11 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
2019/02/25 Python
Python使用random模块生成随机数操作实例详解
2019/09/17 Python
python如何从文件读取数据及解析
2019/09/19 Python
Python全面分析系统的时域特性和频率域特性
2020/02/26 Python
使用jupyter notebook直接打开.md格式的文件
2020/04/10 Python
Django ORM实现按天获取数据去重求和例子
2020/05/18 Python
Microsoft新加坡官方网站:购买微软最新软件和技术产品
2016/10/28 全球购物
捷克钓鱼用品网上商店:Parys.cz
2018/06/15 全球购物
HEMA英国:荷兰原创设计
2018/08/28 全球购物
Java和Javasciprt的区别
2012/09/02 面试题
个人自我鉴定写法
2013/11/30 职场文书
党员思想汇报范文
2013/12/30 职场文书
数控专业毕业生自荐信范文
2014/03/04 职场文书
教师个人读书活动总结
2014/07/08 职场文书
地震捐款倡议书
2014/08/29 职场文书
领导干部学习“三严三实”思想汇报
2014/09/15 职场文书
“四风”问题的主要表现和危害思想汇报
2014/09/19 职场文书
不遵守课堂纪律的检讨书
2014/09/24 职场文书
初中毕业生自我评价
2015/03/02 职场文书
超市督导岗位职责
2015/04/10 职场文书
2016全国“质量月”活动标语口号
2015/12/26 职场文书
Python 用户输入和while循环的操作
2021/05/23 Python
html5表单的required属性使用
2021/07/07 HTML / CSS