Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的ORM框架中SQLAlchemy库的查询操作的教程
Apr 25 Python
详解Python编程中对Monkey Patch猴子补丁开发方式的运用
May 27 Python
Python 给屏幕打印信息加上颜色的实现方法
Apr 24 Python
Python实现的爬取百度贴吧图片功能完整示例
May 10 Python
python中break、continue 、exit() 、pass终止循环的区别详解
Jul 08 Python
手把手教你pycharm专业版安装破解教程(linux版)
Sep 26 Python
python getpass模块用法及实例详解
Oct 07 Python
python中使用you-get库批量在线下载bilibili视频的教程
Mar 10 Python
关于Python turtle库使用时坐标的确定方法
Mar 19 Python
解决django FileFIELD的编码问题
Mar 30 Python
python 多进程和协程配合使用写入数据
Oct 30 Python
python定时截屏实现
Nov 02 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
对象失去焦点时自己动提交数据的实现代码
2012/11/06 PHP
PHP获取数组中某元素的位置及array_keys函数应用
2013/01/29 PHP
PHP下用Swoole实现Actor并发模型的方法
2019/06/12 PHP
PHP如何实现阿里云短信sdk灵活应用在项目中的方法
2019/06/14 PHP
Laravel 登录后清空COOKIE的操作方法
2019/10/14 PHP
基于PHP实现用户在线状态检测
2020/11/10 PHP
发布BlueShow v1.0 图片浏览器(类似lightbox)blueshow.js 打包下载
2007/07/21 Javascript
解决Extjs4中form表单提交后无法进入success函数问题
2013/11/26 Javascript
jQuery UI插件自定义confirm确认框的方法
2015/03/20 Javascript
JavaScript动态改变div属性的实现方法
2015/07/22 Javascript
jQuery实现二级下拉菜单效果
2016/01/05 Javascript
JavaScript模拟实现封装的三种方式及写法区别
2017/10/27 Javascript
jQuery ajax读取本地json文件的实例
2017/10/31 jQuery
Vue.js做select下拉列表的实例(ul-li标签仿select标签)
2018/03/02 Javascript
Vue2.0结合webuploader实现文件分片上传功能
2018/03/09 Javascript
vue 实现全选全不选的示例代码
2018/03/29 Javascript
Vue中的$set的使用实例代码
2018/10/08 Javascript
JS/jQuery实现超简单的Table表格添加,删除行功能示例
2019/07/31 jQuery
JS对象属性的检测与获取操作实例分析
2020/03/17 Javascript
[02:05:03]完美世界DOTA2联赛循环赛 LBZS VS Matador BO2 10.28
2020/10/28 DOTA
Python访问纯真IP数据库脚本分享
2015/06/29 Python
Python检测生僻字的实现方法
2016/10/23 Python
python requests 测试代理ip是否生效
2018/07/25 Python
HTML5中语义化 b 和 i 标签
2008/10/17 HTML / CSS
canvas使用注意点总结
2013/07/19 HTML / CSS
Sneaker Studio波兰:购买运动鞋
2018/04/28 全球购物
CNC数控操作工岗位职责
2013/11/19 职场文书
物业总经理助理岗位职责
2014/06/29 职场文书
求职自我评价范文100字
2014/09/23 职场文书
2015年档案管理工作总结
2015/04/08 职场文书
开学季:喜迎新生,迎新标语少不了
2019/11/07 职场文书
2019年共青团工作条例最新版
2019/11/12 职场文书
Redis 持久化 RDB 与 AOF的执行过程
2021/11/07 Redis
Nginx虚拟主机的配置步骤过程全解
2022/03/31 Servers
简单聊聊TypeScript只读修饰符
2022/04/06 Javascript
Linux下搭建SFTP服务器的命令详解
2022/06/25 Servers