keras中的loss、optimizer、metrics用法


Posted in Python onJune 15, 2020

用keras搭好模型架构之后的下一步,就是执行编译操作。在编译时,经常需要指定三个参数

loss

optimizer

metrics

这三个参数有两类选择:

使用字符串

使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数

例如:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
  optimizer=sgd,
  metrics=['accuracy'])

因为有时可以使用字符串,有时可以使用标识符,令人很想知道背后是如何操作的。下面分别针对optimizer,loss,metrics三种对象的获取进行研究。

optimizer

一个模型只能有一个optimizer,在执行编译的时候只能指定一个optimizer。

在keras.optimizers.py中,有一个get函数,用于根据用户传进来的optimizer参数获取优化器的实例:

def get(identifier):
 # 如果后端是tensorflow并且使用的是tensorflow自带的优化器实例,可以直接使用tensorflow原生的优化器 
 if K.backend() == 'tensorflow':
 # Wrap TF optimizer instances
 if isinstance(identifier, tf.train.Optimizer):
  return TFOptimizer(identifier)
 # 如果以json串的形式定义optimizer并进行参数配置
 if isinstance(identifier, dict):
 return deserialize(identifier)
 elif isinstance(identifier, six.string_types):
 # 如果以字符串形式指定optimizer,那么使用优化器的默认配置参数
 config = {'class_name': str(identifier), 'config': {}}
 return deserialize(config)
 if isinstance(identifier, Optimizer):
 # 如果使用keras封装的Optimizer的实例
 return identifier
 else:
 raise ValueError('Could not interpret optimizer identifier: ' +
    str(identifier))

其中,deserilize(config)函数的作用就是把optimizer反序列化制造一个实例。

loss

keras.losses函数也有一个get(identifier)方法。其中需要注意以下一点:

如果identifier是可调用的一个函数名,也就是一个自定义的损失函数,这个损失函数返回值是一个张量。这样就轻而易举的实现了自定义损失函数。除了使用str和dict类型的identifier,我们也可以直接使用keras.losses包下面的损失函数。

def get(identifier):
 if identifier is None:
 return None
 if isinstance(identifier, six.string_types):
 identifier = str(identifier)
 return deserialize(identifier)
 if isinstance(identifier, dict):
 return deserialize(identifier)
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'loss function identifier:', identifier)

metrics

在model.compile()函数中,optimizer和loss都是单数形式,只有metrics是复数形式。因为一个模型只能指明一个optimizer和loss,却可以指明多个metrics。metrics也是三者中处理逻辑最为复杂的一个。

在keras最核心的地方keras.engine.train.py中有如下处理metrics的函数。这个函数其实就做了两件事:

根据输入的metric找到具体的metric对应的函数

计算metric张量

在寻找metric对应函数时,有两种步骤:

使用字符串形式指明准确率和交叉熵

使用keras.metrics.py中的函数

def handle_metrics(metrics, weights=None):
 metric_name_prefix = 'weighted_' if weights is not None else ''

 for metric in metrics:
 # 如果metrics是最常见的那种:accuracy,交叉熵
 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
  # custom handling of accuracy/crossentropy
  # (because of class mode duality)
  output_shape = K.int_shape(self.outputs[i])
  # 如果输出维度是1或者损失函数是二分类损失函数,那么说明是个二分类问题,应该使用二分类的accuracy和二分类的的交叉熵
  if (output_shape[-1] == 1 or
  self.loss_functions[i] == losses.binary_crossentropy):
  # case: binary accuracy/crossentropy
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.binary_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.binary_crossentropy
  # 如果损失函数是sparse_categorical_crossentropy,那么目标y_input就不是one-hot的,所以就需要使用sparse的多类准去率和sparse的多类交叉熵
  elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
  # case: categorical accuracy/crossentropy
  # with sparse targets
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.sparse_categorical_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.sparse_categorical_crossentropy
  else:
  # case: categorical accuracy/crossentropy
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.categorical_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.categorical_crossentropy
  if metric in ('accuracy', 'acc'):
   suffix = 'acc'
  elif metric in ('crossentropy', 'ce'):
   suffix = 'ce'
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  metric_name = metric_name_prefix + suffix
 else:
  # 如果输入的metric不是字符串,那么就调用metrics模块获取
  metric_fn = metrics_module.get(metric)
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  # Get metric name as string
  if hasattr(metric_fn, 'name'):
  metric_name = metric_fn.name
  else:
  metric_name = metric_fn.__name__
  metric_name = metric_name_prefix + metric_name

 with K.name_scope(metric_name):
  metric_result = weighted_metric_fn(y_true, y_pred,
      weights=weights,
      mask=masks[i])

 # Append to self.metrics_names, self.metric_tensors,
 # self.stateful_metric_names
 if len(self.output_names) > 1:
  metric_name = self.output_names[i] + '_' + metric_name
 # Dedupe name
 j = 1
 base_metric_name = metric_name
 while metric_name in self.metrics_names:
  metric_name = base_metric_name + '_' + str(j)
  j += 1
 self.metrics_names.append(metric_name)
 self.metrics_tensors.append(metric_result)

 # Keep track of state updates created by
 # stateful metrics (i.e. metrics layers).
 if isinstance(metric_fn, Layer) and metric_fn.stateful:
  self.stateful_metric_names.append(metric_name)
  self.stateful_metric_functions.append(metric_fn)
  self.metrics_updates += metric_fn.updates

无论怎么使用metric,最终都会变成metrics包下面的函数。当使用字符串形式指明accuracy和crossentropy时,keras会非常智能地确定应该使用metrics包下面的哪个函数。因为metrics包下的那些metric函数有不同的使用场景,例如:

有的处理的是one-hot形式的y_input(数据的类别),有的处理的是非one-hot形式的y_input

有的处理的是二分类问题的metric,有的处理的是多分类问题的metric

当使用字符串“accuracy”和“crossentropy”指明metric时,keras会根据损失函数、输出层的shape来确定具体应该使用哪个metric函数。在任何情况下,直接使用metrics下面的函数名是总不会出错的。

keras.metrics.py文件中也有一个get(identifier)函数用于获取metric函数。

def get(identifier):
 if isinstance(identifier, dict):
 config = {'class_name': str(identifier), 'config': {}}
 return deserialize(config)
 elif isinstance(identifier, six.string_types):
 return deserialize(str(identifier))
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'metric function identifier:', identifier)

如果identifier是字符串或者字典,那么会根据identifier反序列化出一个metric函数。

如果identifier本身就是一个函数名,那么就直接返回这个函数名。这种方式就为自定义metric提供了巨大便利。

keras中的设计哲学堪称完美。

以上这篇keras中的loss、optimizer、metrics用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python + openpyxl处理excel2007文档思路以及心得
Jul 14 Python
Python中pygame的mouse鼠标事件用法实例
Nov 11 Python
Python中函数参数设置及使用的学习笔记
May 03 Python
python查看微信好友是否删除自己
Dec 19 Python
关于Python正则表达式 findall函数问题详解
Mar 22 Python
python 随机打乱 图片和对应的标签方法
Dec 14 Python
python中的函数递归和迭代原理解析
Nov 14 Python
Python任务自动化工具tox使用教程
Mar 17 Python
Django+Celery实现动态配置定时任务的方法示例
May 26 Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 Python
详解Pycharm与anaconda安装配置指南
Aug 25 Python
Python3+RIDE+RobotFramework自动化测试框架搭建过程详解
Sep 23 Python
使用keras实现Precise, Recall, F1-socre方式
Jun 15 #Python
基于python和flask实现http接口过程解析
Jun 15 #Python
基于nexus3配置Python仓库过程详解
Jun 15 #Python
Keras官方中文文档:性能评估Metrices详解
Jun 15 #Python
在keras里面实现计算f1-score的代码
Jun 15 #Python
Python流程控制语句的深入讲解
Jun 15 #Python
keras自定义损失函数并且模型加载的写法介绍
Jun 15 #Python
You might like
新手学PHP之数据库操作详解及乱码解决!
2007/01/02 PHP
PHP基于imap获取邮件实例
2014/11/11 PHP
WordPress分页伪静态加html后缀
2016/06/08 PHP
如何使用json在前后台进行数据传输实例介绍
2013/04/11 Javascript
jquery的flexigrid无法显示数据提示获取到数据
2013/07/19 Javascript
jquery实现焦点图片随机切换效果的方法
2015/03/12 Javascript
jQuery焦点图切换特效代码分享
2015/09/15 Javascript
javascript判断复选框是否选中的方法
2015/10/16 Javascript
js实现加载更多功能实例
2016/10/27 Javascript
jQuery插件HighCharts实现的2D堆条状图效果示例【附demo源码下载】
2017/03/14 Javascript
Windows下快速搭建NodeJS本地服务器的步骤
2017/08/09 NodeJs
在Vue中使用echarts的方法
2018/02/05 Javascript
解决layer.msg 不居中 ifram中的问题
2019/09/05 Javascript
[45:16]完美世界DOTA2联赛PWL S3 Magma vs Phoenix 第一场 12.12
2020/12/16 DOTA
利用python获得时间的实例说明
2013/03/25 Python
一个简单的python程序实例(通讯录)
2013/11/29 Python
python实现合并两个数组的方法
2015/05/16 Python
python图像处理之反色实现方法
2015/05/30 Python
Python实现自动添加脚本头信息的示例代码
2016/09/02 Python
Django Admin 实现外键过滤的方法
2017/09/29 Python
python 实现批量xls文件转csv文件的方法
2018/10/23 Python
python频繁写入文件时提速的方法
2019/06/26 Python
Python测试模块doctest使用解析
2019/08/10 Python
5行Python代码实现图像分割的步骤详解
2020/05/25 Python
python绘制分布折线图的示例
2020/09/24 Python
Nike台湾官方商店:Nike.com (TW)
2017/08/16 全球购物
来自全球大都市的高级街头服饰:Pegador
2018/01/03 全球购物
戴尔英国翻新电脑和电子产品:Dell UK Refurbished Computers
2019/07/30 全球购物
店长助理岗位职责
2013/12/13 职场文书
经济贸易专业自荐信
2014/06/11 职场文书
活动总结格式
2014/08/30 职场文书
五一劳动节慰问信
2015/02/14 职场文书
2015年百日安全活动总结
2015/03/26 职场文书
2015年人力资源工作总结
2015/04/08 职场文书
Nginx解决前端访问资源跨域问题的方法详解
2021/03/31 Servers
Golang使用Panic与Recover进行错误捕获
2022/03/22 Golang