Keras loss函数剖析


Posted in Python onJuly 06, 2020

我就废话不多说了,大家还是直接看代码吧~

'''
Created on 2018-4-16
'''
def compile(
self,
optimizer, #优化器
loss, #损失函数,可以为已经定义好的loss函数名称,也可以为自己写的loss函数
metrics=None, #
sample_weight_mode=None, #如果你需要按时间步为样本赋权(2D权矩阵),将该值设为“temporal”。默认为“None”,代表按样本赋权(1D权),和fit中sample_weight在赋值样本权重中配合使用
weighted_metrics=None, 
target_tensors=None,
**kwargs #这里的设定的参数可以和后端交互。
)

实质调用的是Keras\engine\training.py 中的class Model中的def compile
一般使用model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

# keras所有定义好的损失函数loss:
# keras\losses.py
# 有些loss函数可以使用简称:
# mse = MSE = mean_squared_error
# mae = MAE = mean_absolute_error
# mape = MAPE = mean_absolute_percentage_error
# msle = MSLE = mean_squared_logarithmic_error
# kld = KLD = kullback_leibler_divergence
# cosine = cosine_proximity
# 使用到的数学方法:
# mean:求均值
# sum:求和
# square:平方
# abs:绝对值
# clip:[裁剪替换](https://blog.csdn.net/qq1483661204/article/details)
# epsilon:1e-7
# log:以e为底
# maximum(x,y):x与 y逐位比较取其大者
# reduce_sum(x,axis):沿着某个维度求和
# l2_normalize:l2正则化
# softplus:softplus函数
# 
# import cntk as C
# 1.mean_squared_error:
#  return K.mean(K.square(y_pred - y_true), axis=-1) 
# 2.mean_absolute_error:
#  return K.mean(K.abs(y_pred - y_true), axis=-1)
# 3.mean_absolute_percentage_error:
#  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
#  return 100. * K.mean(diff, axis=-1)
# 4.mean_squared_logarithmic_error:
#  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
#  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
#  return K.mean(K.square(first_log - second_log), axis=-1)
# 5.squared_hinge:
#  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)
# 6.hinge(SVM损失函数):
#  return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)
# 7.categorical_hinge:
#  pos = K.sum(y_true * y_pred, axis=-1)
#  neg = K.max((1. - y_true) * y_pred, axis=-1)
#  return K.maximum(0., neg - pos + 1.)
# 8.logcosh:
#  def _logcosh(x):
#   return x + K.softplus(-2. * x) - K.log(2.)
#  return K.mean(_logcosh(y_pred - y_true), axis=-1)
# 9.categorical_crossentropy:
#  output /= C.reduce_sum(output, axis=-1)
#  output = C.clip(output, epsilon(), 1.0 - epsilon())
#  return -sum(target * C.log(output), axis=-1)
# 10.sparse_categorical_crossentropy:
#  target = C.one_hot(target, output.shape[-1])
#  target = C.reshape(target, output.shape)
#  return categorical_crossentropy(target, output, from_logits)
# 11.binary_crossentropy:
#  return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
# 12.kullback_leibler_divergence:
#  y_true = K.clip(y_true, K.epsilon(), 1)
#  y_pred = K.clip(y_pred, K.epsilon(), 1)
#  return K.sum(y_true * K.log(y_true / y_pred), axis=-1)
# 13.poisson:
#  return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
# 14.cosine_proximity:
#  y_true = K.l2_normalize(y_true, axis=-1)
#  y_pred = K.l2_normalize(y_pred, axis=-1)
#  return -K.sum(y_true * y_pred, axis=-1)

补充知识:一文总结Keras的loss函数和metrics函数

Loss函数

定义:

keras.losses.mean_squared_error(y_true, y_pred)

用法很简单,就是计算均方误差平均值,例如

loss_fn = keras.losses.mean_squared_error
a1 = tf.constant([1,1,1,1])
a2 = tf.constant([2,2,2,2])
loss_fn(a1,a2)
<tf.Tensor: id=718367, shape=(), dtype=int32, numpy=1>

Metrics函数

Metrics函数也用于计算误差,但是功能比Loss函数要复杂。

定义

tf.keras.metrics.Mean(
  name='mean', dtype=None
)

这个定义过于简单,举例说明

mean_loss([1, 3, 5, 7])
mean_loss([1, 3, 5, 7])
mean_loss([1, 1, 1, 1])
mean_loss([2,2])

输出结果

<tf.Tensor: id=718929, shape=(), dtype=float32, numpy=2.857143>

这个结果等价于

np.mean([1, 3, 5, 7, 1, 3, 5, 7, 1, 1, 1, 1, 2, 2])

这是因为Metrics函数是状态函数,在神经网络训练过程中会持续不断地更新状态,是有记忆的。因为Metrics函数还带有下面几个Methods

reset_states()
Resets all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.

result()
Computes and returns the metric value tensor.
Result computation is an idempotent operation that simply calculates the metric value using the state variables

update_state(
  values, sample_weight=None
)
Accumulates statistics for computing the reduction metric.

另外注意,Loss函数和Metrics函数的调用形式,

loss_fn = keras.losses.mean_squared_error mean_loss = keras.metrics.Mean()

mean_loss(1)等价于keras.metrics.Mean()(1),而不是keras.metrics.Mean(1),这个从keras.metrics.Mean函数的定义可以看出。

但是必须先令生成一个实例mean_loss=keras.metrics.Mean(),而不能直接使用keras.metrics.Mean()本身。

以上这篇Keras loss函数剖析就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
解决Pycharm中import时无法识别自己写的程序方法
May 18 Python
解决Pycharm运行时找不到文件的问题
Oct 29 Python
python使用PyQt5的简单方法
Feb 27 Python
浅谈Python编程中3个常用的数据结构和算法
Apr 30 Python
python利用7z批量解压rar的实现
Aug 07 Python
Python Collatz序列实现过程解析
Oct 12 Python
python自动识别文本编码格式代码
Dec 26 Python
python使用numpy实现直方图反向投影示例
Jan 17 Python
浅谈selenium如何应对网页内容需要鼠标滚动加载的问题
Mar 14 Python
Python局部变量与全局变量区别原理解析
Jul 14 Python
python中取绝对值简单方法总结
Jul 24 Python
python 绘制正态曲线的示例
Sep 24 Python
keras 模型参数,模型保存,中间结果输出操作
Jul 06 #Python
Python自省及反射原理实例详解
Jul 06 #Python
如何通过命令行进入python
Jul 06 #Python
解决TensorFlow调用Keras库函数存在的问题
Jul 06 #Python
python else语句在循环中的运用详解
Jul 06 #Python
Keras模型转成tensorflow的.pb操作
Jul 06 #Python
python如何进入交互模式
Jul 06 #Python
You might like
UCenter 批量添加用户的php代码
2012/07/17 PHP
php 删除目录下N分钟前创建的所有文件的实现代码
2013/08/10 PHP
php上传图片之时间戳命名(保存路径)
2014/08/15 PHP
Yii模型操作之criteria查找数据库的方法
2016/07/15 PHP
thinkPHP内置字符串截取函数用法详解
2016/11/15 PHP
golang实现php里的serialize()和unserialize()序列和反序列方法详解
2018/10/30 PHP
js资料prototype 属性
2007/03/13 Javascript
基于jquery的has()方法以及与find()方法以及filter()方法的区别详解
2013/04/26 Javascript
通过length属性判断jquery对象是否存在
2013/10/18 Javascript
浅析js设置控件的readonly与enabled属性问题
2013/12/25 Javascript
JS实现div居中示例
2014/04/17 Javascript
jQuery操作元素css样式的三种方法
2014/06/04 Javascript
JavaScript通过元素的ID和name设置样式
2014/07/08 Javascript
使用js画图之画切线
2015/01/12 Javascript
PHP+jQuery实现随意拖动层并即时保存拖动位置
2015/04/30 Javascript
浅析Node.js 中 Stream API 的使用
2015/10/23 Javascript
详解js中==与===的区别
2017/01/08 Javascript
几种响应式文字详解
2017/05/19 Javascript
JavaScript仿微信(电话)联系人列表滑动字母索引实例讲解(推荐)
2017/08/16 Javascript
点击按钮弹出模态框的一系列操作代码实例
2019/03/29 Javascript
sharp.js安装过程中遇到的问题总结
2020/04/02 Javascript
Python切片知识解析
2016/03/06 Python
Python实现k-means算法
2018/02/23 Python
基于Python log 的正确打开方式
2018/04/28 Python
详解Python用户登录接口的方法
2019/04/17 Python
tensorflow之tf.record实现存浮点数数组
2020/02/17 Python
Python龙贝格法求积分实例
2020/02/29 Python
Python爬虫自动化爬取b站实时弹幕实例方法
2021/01/26 Python
CSS3 3D旋转rotate效果实例介绍
2016/05/03 HTML / CSS
html5中localStorage本地存储的简单使用
2017/06/16 HTML / CSS
优秀本科生求职推荐信
2014/02/24 职场文书
市委常委班子党的群众路线教育实践活动整改措施
2014/10/02 职场文书
2015年质检工作总结
2015/05/04 职场文书
我的长征观后感
2015/06/09 职场文书
python使用matplotlib绘制图片时x轴的刻度处理
2021/08/30 Python
gojs实现蚂蚁线动画效果
2022/02/18 Javascript