基于Keras的格式化输出Loss实现方式


Posted in Python onJune 17, 2020

在win7 64位,Anaconda安装的Python3.6.1下安装的TensorFlow与Keras,Keras的backend为TensorFlow。在运行Mask R-CNN时,在进行调试时想知道PyCharm (Python IDE)底部窗口输出的Loss格式是在哪里定义的,如下图红框中所示:

基于Keras的格式化输出Loss实现方式

图1 训练过程的Loss格式化输出

在上图红框中,Loss的输出格式是在哪里定义的呢?有一点是明确的,即上图红框中的内容是在训练的时候输出的。那么先来看一下Mask R-CNN的训练过程。Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用 fit 函数。然而由于Mask R-CNN训练数据巨大,不能一次性全部载入,否则太消耗内存。于是采用生成器的方式一次载入一个batch的数据,而且是在用到这个batch的数据才开始载入的,那么它的训练函数如下:

self.keras_model.fit_generator(
   train_generator,
   initial_epoch=self.epoch,
   epochs=epochs,
   steps_per_epoch=self.config.STEPS_PER_EPOCH,
   callbacks=callbacks,
   validation_data=val_generator,
   validation_steps=self.config.VALIDATION_STEPS,
   max_queue_size=100,
   workers=workers,
   use_multiprocessing=False,
  )

这里训练模型的函数相应的为 fit_generator 函数。注意其中的参数callbacks=callbacks,这个参数在输出红框中的内容起到了关键性的作用。下面看一下callbacks的值:

# Callbacks
  callbacks = [
   keras.callbacks.TensorBoard(log_dir=self.log_dir,
          histogram_freq=0, write_graph=True, write_images=False),
   keras.callbacks.ModelCheckpoint(self.checkpoint_path,
           verbose=0, save_weights_only=True),
  ]

在输出红框中的内容所需的数据均保存在self.log_dir下。然后调试进入self.keras_model.fit_generator函数,进入keras,legacy.interfaces的legacy_support(func)函数,如下所示:

def legacy_support(func):
  @six.wraps(func)
  def wrapper(*args, **kwargs):
   if object_type == 'class':
    object_name = args[0].__class__.__name__
   else:
    object_name = func.__name__
   if preprocessor:
    args, kwargs, converted = preprocessor(args, kwargs)
   else:
    converted = []
   if check_positional_args:
    if len(args) > len(allowed_positional_args) + 1:
     raise TypeError('`' + object_name +
         '` can accept only ' +
         str(len(allowed_positional_args)) +
         ' positional arguments ' +
         str(tuple(allowed_positional_args)) +
         ', but you passed the following '
         'positional arguments: ' +
         str(list(args[1:])))
   for key in value_conversions:
    if key in kwargs:
     old_value = kwargs[key]
     if old_value in value_conversions[key]:
      kwargs[key] = value_conversions[key][old_value]
   for old_name, new_name in conversions:
    if old_name in kwargs:
     value = kwargs.pop(old_name)
     if new_name in kwargs:
      raise_duplicate_arg_error(old_name, new_name)
     kwargs[new_name] = value
     converted.append((new_name, old_name))
   if converted:
    signature = '`' + object_name + '('
    for i, value in enumerate(args[1:]):
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(args[1:]) - 1 or kwargs:
      signature += ', '
    for i, (name, value) in enumerate(kwargs.items()):
     signature += name + '='
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(kwargs) - 1:
      signature += ', '
    signature += ')`'
    warnings.warn('Update your `' + object_name +
        '` call to the Keras 2 API: ' + signature, stacklevel=2)
   return func(*args, **kwargs)
  wrapper._original_function = func
  return wrapper
 return legacy_support

在上述代码的倒数第4行的return func(*args, **kwargs)处返回func,func为fit_generator函数,现调试进入fit_generator函数,该函数定义在keras.engine.training模块内的fit_generator函数,调试进入函数callbacks.on_epoch_begin(epoch),如下所示:

# Construct epoch logs.
   epoch_logs = {}
   while epoch < epochs:
    for m in self.stateful_metric_functions:
     m.reset_states()
    callbacks.on_epoch_begin(epoch)

调试进入到callbacks.on_epoch_begin(epoch)函数,进入on_epoch_begin函数,如下所示:

def on_epoch_begin(self, epoch, logs=None):
  """Called at the start of an epoch.
  # Arguments
   epoch: integer, index of epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  for callback in self.callbacks:
   callback.on_epoch_begin(epoch, logs)
  self._delta_t_batch = 0.
  self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
  self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

在上述函数on_epoch_begin中调试进入callback.on_epoch_begin(epoch, logs)函数,转到类ProgbarLogger(Callback)中定义的on_epoch_begin函数,如下所示:

class ProgbarLogger(Callback):
 """Callback that prints metrics to stdout.
 # Arguments
  count_mode: One of "steps" or "samples".
   Whether the progress bar should
   count samples seen or steps (batches) seen.
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over an epoch.
   Metrics in this list will be logged as-is.
   All others will be averaged over time (e.g. loss, etc).
 # Raises
  ValueError: In case of invalid `count_mode`.
 """
 
 def __init__(self, count_mode='samples',
     stateful_metrics=None):
  super(ProgbarLogger, self).__init__()
  if count_mode == 'samples':
   self.use_steps = False
  elif count_mode == 'steps':
   self.use_steps = True
  else:
   raise ValueError('Unknown `count_mode`: ' + str(count_mode))
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
 def on_train_begin(self, logs=None):
  self.verbose = self.params['verbose']
  self.epochs = self.params['epochs']
 
 def on_epoch_begin(self, epoch, logs=None):
  if self.verbose:
   print('Epoch %d/%d' % (epoch + 1, self.epochs))
   if self.use_steps:
    target = self.params['steps']
   else:
    target = self.params['samples']
   self.target = target
   self.progbar = Progbar(target=self.target,
         verbose=self.verbose,
         stateful_metrics=self.stateful_metrics)
  self.seen = 0

在上述代码的

print('Epoch %d/%d' % (epoch + 1, self.epochs))

输出

Epoch 1/40(如红框中所示内容的第一行)。

然后返回到keras.engine.training模块内的fit_generator函数,执行到self.train_on_batch函数,如下所示:

outs = self.train_on_batch(x, y,
     sample_weight=sample_weight,
     class_weight=class_weight)
 
     if not isinstance(outs, list):
      outs = [outs]
     for l, o in zip(out_labels, outs):
      batch_logs[l] = o
 
     callbacks.on_batch_end(batch_index, batch_logs)
 
     batch_index += 1
     steps_done += 1

调试进入上述代码中的callbacks.on_batch_end(batch_index, batch_logs)函数,进入到on_batch_end函数后,该函数的定义如下所示:

def on_batch_end(self, batch, logs=None):
  """Called at the end of a batch.
  # Arguments
   batch: integer, index of batch within the current epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  if not hasattr(self, '_t_enter_batch'):
   self._t_enter_batch = time.time()
  self._delta_t_batch = time.time() - self._t_enter_batch
  t_before_callbacks = time.time()
  for callback in self.callbacks:
   callback.on_batch_end(batch, logs)
  self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
  delta_t_median = np.median(self._delta_ts_batch_end)
  if (self._delta_t_batch > 0. and
   (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
   warnings.warn('Method on_batch_end() is slow compared '
       'to the batch update (%f). Check your callbacks.'
       % delta_t_median)

接着继续调试进入上述代码中的callback.on_batch_end(batch, logs)函数,进入到在类中ProgbarLogger(Callback)定义的on_batch_end函数,如下所示:

def on_batch_end(self, batch, logs=None):
  logs = logs or {}
  batch_size = logs.get('size', 0)
  if self.use_steps:
   self.seen += 1
  else:
   self.seen += batch_size
 
  for k in self.params['metrics']:
   if k in logs:
    self.log_values.append((k, logs[k]))
 
  # Skip progbar update for the last batch;
  # will be handled by on_epoch_end.
  if self.verbose and self.seen < self.target:
   self.progbar.update(self.seen, self.log_values)

然后执行到上述代码的最后一行self.progbar.update(self.seen, self.log_values),调试进入update函数,该函数定义在模块keras.utils.generic_utils中的类Progbar(object)定义的函数。类的定义及方法如下所示:

class Progbar(object):
 """Displays a progress bar.
 # Arguments
  target: Total number of steps expected, None if unknown.
  width: Progress bar width on screen.
  verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over time. Metrics in this list
   will be displayed as-is. All others will be averaged
   by the progbar before display.
  interval: Minimum visual progress update interval (in seconds).
 """
 
 def __init__(self, target, width=30, verbose=1, interval=0.05,
     stateful_metrics=None):
  self.target = target
  self.width = width
  self.verbose = verbose
  self.interval = interval
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
  self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
         sys.stdout.isatty()) or
         'ipykernel' in sys.modules)
  self._total_width = 0
  self._seen_so_far = 0
  self._values = collections.OrderedDict()
  self._start = time.time()
  self._last_update = 0
 
 def update(self, current, values=None):
  """Updates the progress bar.
  # Arguments
   current: Index of current step.
   values: List of tuples:
    `(name, value_for_last_step)`.
    If `name` is in `stateful_metrics`,
    `value_for_last_step` will be displayed as-is.
    Else, an average of the metric over time will be displayed.
  """
  values = values or []
  for k, v in values:
   if k not in self.stateful_metrics:
    if k not in self._values:
     self._values[k] = [v * (current - self._seen_so_far),
          current - self._seen_so_far]
    else:
     self._values[k][0] += v * (current - self._seen_so_far)
     self._values[k][1] += (current - self._seen_so_far)
   else:
    # Stateful metrics output a numeric value. This representation
    # means "take an average from a single value" but keeps the
    # numeric formatting.
    self._values[k] = [v, 1]
  self._seen_so_far = current
 
  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
   if (now - self._last_update < self.interval and
     self.target is not None and current < self.target):
    return
 
   prev_total_width = self._total_width
   if self._dynamic_display:
    sys.stdout.write('\b' * prev_total_width)
    sys.stdout.write('\r')
   else:
    sys.stdout.write('\n')
 
   if self.target is not None:
    numdigits = int(np.floor(np.log10(self.target))) + 1
    barstr = '%%%dd/%d [' % (numdigits, self.target)
    bar = barstr % current
    prog = float(current) / self.target
    prog_width = int(self.width * prog)
    if prog_width > 0:
     bar += ('=' * (prog_width - 1))
     if current < self.target:
      bar += '>'
     else:
      bar += '='
    bar += ('.' * (self.width - prog_width))
    bar += ']'
   else:
    bar = '%7d/Unknown' % current
 
   self._total_width = len(bar)
   sys.stdout.write(bar)
 
   if current:
    time_per_unit = (now - self._start) / current
   else:
    time_per_unit = 0
   if self.target is not None and current < self.target:
    eta = time_per_unit * (self.target - current)
    if eta > 3600:
     eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
    elif eta > 60:
     eta_format = '%d:%02d' % (eta // 60, eta % 60)
    else:
     eta_format = '%ds' % eta
 
    info = ' - ETA: %s' % eta_format
   else:
    if time_per_unit >= 1:
     info += ' %.0fs/step' % time_per_unit
    elif time_per_unit >= 1e-3:
     info += ' %.0fms/step' % (time_per_unit * 1e3)
    else:
     info += ' %.0fus/step' % (time_per_unit * 1e6)
 
   for k in self._values:
    info += ' - %s:' % k
    if isinstance(self._values[k], list):
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if abs(avg) > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    else:
     info += ' %s' % self._values[k]
 
   self._total_width += len(info)
   if prev_total_width > self._total_width:
    info += (' ' * (prev_total_width - self._total_width))
 
   if self.target is not None and current >= self.target:
    info += '\n'
 
   sys.stdout.write(info)
   sys.stdout.flush()
 
  elif self.verbose == 2:
   if self.target is None or current >= self.target:
    for k in self._values:
     info += ' - %s:' % k
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if avg > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    info += '\n'
 
    sys.stdout.write(info)
    sys.stdout.flush()
 
  self._last_update = now
 
 def add(self, n, values=None):
  self.update(self._seen_so_far + n, values)

重点是上述代码中的update(self, current, values=None)函数,在该函数内设置断点,即可调入该函数。下面重点分析上述代码中的几个输出条目:

1. sys.stdout.write('\n') #换行

2. sys.stdout.write('bar') #输出 [..................],其中bar= [..................];

3. sys.stdout.write(info) #输出loss格式,其中info='- ETA:...';

4. sys.stdout.flush() #刷新缓存,立即得到输出。

通过对Mask R-CNN代码的调试分析可知,图1中的红框中的训练过程中的Loss格式化输出是由built-in模块实现的。若想得到类似的格式化输出,关键在self.keras_model.fit_generator函数中传入callbacks参数和callbacks中内容的定义。

以上这篇基于Keras的格式化输出Loss实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Python语言中的按位运算符
Nov 26 Python
利用Python和OpenCV库将URL转换为OpenCV格式的方法
Mar 27 Python
使用Python的Flask框架实现视频的流媒体传输
Mar 31 Python
详解python之简单主机批量管理工具
Jan 27 Python
实践Vim配置python开发环境
Jul 02 Python
使用python的pandas为你的股票绘制趋势图
Jun 26 Python
python数据爬下来保存的位置
Feb 17 Python
Python调用shell命令常用方法(4种)
May 11 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 Python
Python Merge函数原理及用法解析
Sep 16 Python
Python+unittest+DDT实现数据驱动测试
Nov 30 Python
OpenCV绘制圆端矩形的示例代码
Aug 30 Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
You might like
php多重接口的实现方法
2015/06/20 PHP
详解PHP数组赋值方法
2015/11/07 PHP
Yii2前后台分离及migrate使用(七)
2016/05/04 PHP
脚本吧 - 幻宇工作室用到js,超强推荐share.js
2006/12/23 Javascript
静态页面下用javascript操作ACCESS数据库(读增改删)的代码
2007/05/14 Javascript
Javascript事件实例详解
2013/11/06 Javascript
js实现交换运动效果的方法
2015/04/10 Javascript
Javascript实现Array和String互转换的方法
2015/12/21 Javascript
jquery悬浮提示框完整实例
2016/01/13 Javascript
jQuery图片旋转插件jQueryRotate.js用法实例(附demo下载)
2016/01/21 Javascript
网页中JS函数自动执行常用三种方法
2016/03/30 Javascript
浅谈JavaScript的push(),pop(),concat()方法
2016/06/03 Javascript
AngularJS中run方法的巧妙运用
2017/01/04 Javascript
jQuery实现限制文本框的输入长度
2017/01/11 Javascript
JavaScript html5利用FileReader实现上传功能
2020/03/27 Javascript
laravel5.4+vue+element简单搭建的示例代码
2017/08/29 Javascript
Node.js + express实现上传大文件的方法分析【图片、文本文件】
2019/03/14 Javascript
JavaScript跳出循环的三种方法(break, return, continue)
2019/07/30 Javascript
react+antd 递归实现树状目录操作
2020/11/02 Javascript
Python的高级Git库 Gittle
2014/09/22 Python
Python模拟登陆淘宝并统计淘宝消费情况的代码实例分享
2016/07/04 Python
pyspark.sql.DataFrame与pandas.DataFrame之间的相互转换实例
2018/08/02 Python
python写日志文件操作类与应用示例
2019/07/01 Python
Python库安装速度过慢解决方案
2020/07/14 Python
利用html5的websocket实现websocket聊天室
2013/12/12 HTML / CSS
利用HTML5 Canvas API绘制矩形的超级攻略
2016/03/21 HTML / CSS
html5唤醒APP小记
2019/03/27 HTML / CSS
澳大利亚窗帘商店:Curtain Wonderland
2019/12/01 全球购物
Woods官网:加拿大最古老、最受尊敬的户外品牌之一
2020/09/12 全球购物
北大自主招生自荐信
2013/10/19 职场文书
优秀经理获奖感言
2014/03/04 职场文书
工程安全生产协议书
2014/11/21 职场文书
交通事故死亡赔偿协议书
2014/12/03 职场文书
幼儿园2015年度工作总结
2015/04/01 职场文书
关于应聘教师的自荐信
2016/01/28 职场文书
SQL IDENTITY_INSERT作用案例详解
2021/08/23 MySQL