Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python开发之str.format()用法实例分析
Feb 22 Python
Python内置模块logging用法实例分析
Feb 12 Python
利用Python读取txt文档的方法讲解
Jun 23 Python
Python3连接SQLServer、Oracle、MySql的方法
Jun 28 Python
解决python中导入win32com.client出错的问题
Jul 26 Python
Django的Modelforms用法简介
Jul 27 Python
python定位xpath 节点位置的方法
Aug 27 Python
使用celery和Django处理异步任务的流程分析
Feb 19 Python
opencv+python实现均值滤波
Feb 19 Python
python+OpenCV实现图像拼接
Mar 05 Python
Django静态资源部署404问题解决方案
May 11 Python
matplotlib设置颜色、标记、线条,让你的图像更加丰富(推荐)
Sep 25 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
浅析php-fpm静态和动态执行方式的比较
2016/11/09 PHP
PHP设计模式之装饰器模式实例详解
2018/02/07 PHP
详解PHP实现支付宝小程序用户授权的工具类
2018/12/25 PHP
Prototype Template对象 学习
2009/07/19 Javascript
jquery实现select选中行、列合计示例
2014/04/25 Javascript
javascript获取网页宽高方法汇总
2015/07/19 Javascript
情人节单身的我是如何在敲完代码之后收到12束玫瑰的(javascript)
2015/08/21 Javascript
学习javascript面向对象 理解javascript原型和原型链
2016/01/04 Javascript
jQuery ajax提交Form表单实例(附demo源码)
2016/04/06 Javascript
深入浅析JavaScript函数前面的加号和叹号
2016/07/09 Javascript
js实现添加可信站点、修改activex安全设置,禁用弹出窗口阻止程序
2016/08/17 Javascript
JS 循环li添加点击事件 (闭包的应用)
2016/12/10 Javascript
浅谈jQuery操作类数组的工具方法
2016/12/23 Javascript
基于Vuejs和Element的注册插件的编写方法
2017/07/03 Javascript
微信小程序实践之动态控制组件的显示/隐藏功能
2018/07/18 Javascript
详解Vue取消eslint语法限制
2018/08/04 Javascript
JS中数组与对象的遍历方法实例小结
2018/08/14 Javascript
iview的table组件自带的过滤器实现
2019/07/12 Javascript
Vue之Mixins(混入)的使用方法
2019/09/24 Javascript
浅谈webpack和webpack-cli模块源码分析
2020/01/19 Javascript
python中stdout输出不缓存的设置方法
2014/05/29 Python
Python爬虫使用脚本登录Github并查看信息
2018/07/16 Python
通过python爬虫赚钱的方法
2019/01/29 Python
python3使用GUI统计代码量
2019/09/18 Python
Python+Redis实现布隆过滤器
2019/12/08 Python
Python和Sublime整合过程图示
2019/12/25 Python
如何在 Django 模板中输出 "{{"
2020/01/24 Python
TensorFlow 输出checkpoint 中的变量名与变量值方式
2020/02/11 Python
基于Python测试程序是否有错误
2020/05/16 Python
英国皇室御用百货:福南梅森(Fortnum & Mason)
2017/12/03 全球购物
Notino罗马尼亚网站:购买香水和化妆品
2019/07/20 全球购物
几个数据库方面的面试题
2016/07/01 面试题
汉语言文学毕业生自荐信范文
2014/03/24 职场文书
群众路线教育实践活动心得体会(教师)
2014/10/31 职场文书
中小学教师继续教育心得体会
2016/01/19 职场文书
2019通用版导游词范本!
2019/08/07 职场文书