基于Keras的格式化输出Loss实现方式


Posted in Python onJune 17, 2020

在win7 64位,Anaconda安装的Python3.6.1下安装的TensorFlow与Keras,Keras的backend为TensorFlow。在运行Mask R-CNN时,在进行调试时想知道PyCharm (Python IDE)底部窗口输出的Loss格式是在哪里定义的,如下图红框中所示:

基于Keras的格式化输出Loss实现方式

图1 训练过程的Loss格式化输出

在上图红框中,Loss的输出格式是在哪里定义的呢?有一点是明确的,即上图红框中的内容是在训练的时候输出的。那么先来看一下Mask R-CNN的训练过程。Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用 fit 函数。然而由于Mask R-CNN训练数据巨大,不能一次性全部载入,否则太消耗内存。于是采用生成器的方式一次载入一个batch的数据,而且是在用到这个batch的数据才开始载入的,那么它的训练函数如下:

self.keras_model.fit_generator(
   train_generator,
   initial_epoch=self.epoch,
   epochs=epochs,
   steps_per_epoch=self.config.STEPS_PER_EPOCH,
   callbacks=callbacks,
   validation_data=val_generator,
   validation_steps=self.config.VALIDATION_STEPS,
   max_queue_size=100,
   workers=workers,
   use_multiprocessing=False,
  )

这里训练模型的函数相应的为 fit_generator 函数。注意其中的参数callbacks=callbacks,这个参数在输出红框中的内容起到了关键性的作用。下面看一下callbacks的值:

# Callbacks
  callbacks = [
   keras.callbacks.TensorBoard(log_dir=self.log_dir,
          histogram_freq=0, write_graph=True, write_images=False),
   keras.callbacks.ModelCheckpoint(self.checkpoint_path,
           verbose=0, save_weights_only=True),
  ]

在输出红框中的内容所需的数据均保存在self.log_dir下。然后调试进入self.keras_model.fit_generator函数,进入keras,legacy.interfaces的legacy_support(func)函数,如下所示:

def legacy_support(func):
  @six.wraps(func)
  def wrapper(*args, **kwargs):
   if object_type == 'class':
    object_name = args[0].__class__.__name__
   else:
    object_name = func.__name__
   if preprocessor:
    args, kwargs, converted = preprocessor(args, kwargs)
   else:
    converted = []
   if check_positional_args:
    if len(args) > len(allowed_positional_args) + 1:
     raise TypeError('`' + object_name +
         '` can accept only ' +
         str(len(allowed_positional_args)) +
         ' positional arguments ' +
         str(tuple(allowed_positional_args)) +
         ', but you passed the following '
         'positional arguments: ' +
         str(list(args[1:])))
   for key in value_conversions:
    if key in kwargs:
     old_value = kwargs[key]
     if old_value in value_conversions[key]:
      kwargs[key] = value_conversions[key][old_value]
   for old_name, new_name in conversions:
    if old_name in kwargs:
     value = kwargs.pop(old_name)
     if new_name in kwargs:
      raise_duplicate_arg_error(old_name, new_name)
     kwargs[new_name] = value
     converted.append((new_name, old_name))
   if converted:
    signature = '`' + object_name + '('
    for i, value in enumerate(args[1:]):
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(args[1:]) - 1 or kwargs:
      signature += ', '
    for i, (name, value) in enumerate(kwargs.items()):
     signature += name + '='
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(kwargs) - 1:
      signature += ', '
    signature += ')`'
    warnings.warn('Update your `' + object_name +
        '` call to the Keras 2 API: ' + signature, stacklevel=2)
   return func(*args, **kwargs)
  wrapper._original_function = func
  return wrapper
 return legacy_support

在上述代码的倒数第4行的return func(*args, **kwargs)处返回func,func为fit_generator函数,现调试进入fit_generator函数,该函数定义在keras.engine.training模块内的fit_generator函数,调试进入函数callbacks.on_epoch_begin(epoch),如下所示:

# Construct epoch logs.
   epoch_logs = {}
   while epoch < epochs:
    for m in self.stateful_metric_functions:
     m.reset_states()
    callbacks.on_epoch_begin(epoch)

调试进入到callbacks.on_epoch_begin(epoch)函数,进入on_epoch_begin函数,如下所示:

def on_epoch_begin(self, epoch, logs=None):
  """Called at the start of an epoch.
  # Arguments
   epoch: integer, index of epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  for callback in self.callbacks:
   callback.on_epoch_begin(epoch, logs)
  self._delta_t_batch = 0.
  self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
  self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

在上述函数on_epoch_begin中调试进入callback.on_epoch_begin(epoch, logs)函数,转到类ProgbarLogger(Callback)中定义的on_epoch_begin函数,如下所示:

class ProgbarLogger(Callback):
 """Callback that prints metrics to stdout.
 # Arguments
  count_mode: One of "steps" or "samples".
   Whether the progress bar should
   count samples seen or steps (batches) seen.
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over an epoch.
   Metrics in this list will be logged as-is.
   All others will be averaged over time (e.g. loss, etc).
 # Raises
  ValueError: In case of invalid `count_mode`.
 """
 
 def __init__(self, count_mode='samples',
     stateful_metrics=None):
  super(ProgbarLogger, self).__init__()
  if count_mode == 'samples':
   self.use_steps = False
  elif count_mode == 'steps':
   self.use_steps = True
  else:
   raise ValueError('Unknown `count_mode`: ' + str(count_mode))
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
 def on_train_begin(self, logs=None):
  self.verbose = self.params['verbose']
  self.epochs = self.params['epochs']
 
 def on_epoch_begin(self, epoch, logs=None):
  if self.verbose:
   print('Epoch %d/%d' % (epoch + 1, self.epochs))
   if self.use_steps:
    target = self.params['steps']
   else:
    target = self.params['samples']
   self.target = target
   self.progbar = Progbar(target=self.target,
         verbose=self.verbose,
         stateful_metrics=self.stateful_metrics)
  self.seen = 0

在上述代码的

print('Epoch %d/%d' % (epoch + 1, self.epochs))

输出

Epoch 1/40(如红框中所示内容的第一行)。

然后返回到keras.engine.training模块内的fit_generator函数,执行到self.train_on_batch函数,如下所示:

outs = self.train_on_batch(x, y,
     sample_weight=sample_weight,
     class_weight=class_weight)
 
     if not isinstance(outs, list):
      outs = [outs]
     for l, o in zip(out_labels, outs):
      batch_logs[l] = o
 
     callbacks.on_batch_end(batch_index, batch_logs)
 
     batch_index += 1
     steps_done += 1

调试进入上述代码中的callbacks.on_batch_end(batch_index, batch_logs)函数,进入到on_batch_end函数后,该函数的定义如下所示:

def on_batch_end(self, batch, logs=None):
  """Called at the end of a batch.
  # Arguments
   batch: integer, index of batch within the current epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  if not hasattr(self, '_t_enter_batch'):
   self._t_enter_batch = time.time()
  self._delta_t_batch = time.time() - self._t_enter_batch
  t_before_callbacks = time.time()
  for callback in self.callbacks:
   callback.on_batch_end(batch, logs)
  self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
  delta_t_median = np.median(self._delta_ts_batch_end)
  if (self._delta_t_batch > 0. and
   (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
   warnings.warn('Method on_batch_end() is slow compared '
       'to the batch update (%f). Check your callbacks.'
       % delta_t_median)

接着继续调试进入上述代码中的callback.on_batch_end(batch, logs)函数,进入到在类中ProgbarLogger(Callback)定义的on_batch_end函数,如下所示:

def on_batch_end(self, batch, logs=None):
  logs = logs or {}
  batch_size = logs.get('size', 0)
  if self.use_steps:
   self.seen += 1
  else:
   self.seen += batch_size
 
  for k in self.params['metrics']:
   if k in logs:
    self.log_values.append((k, logs[k]))
 
  # Skip progbar update for the last batch;
  # will be handled by on_epoch_end.
  if self.verbose and self.seen < self.target:
   self.progbar.update(self.seen, self.log_values)

然后执行到上述代码的最后一行self.progbar.update(self.seen, self.log_values),调试进入update函数,该函数定义在模块keras.utils.generic_utils中的类Progbar(object)定义的函数。类的定义及方法如下所示:

class Progbar(object):
 """Displays a progress bar.
 # Arguments
  target: Total number of steps expected, None if unknown.
  width: Progress bar width on screen.
  verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over time. Metrics in this list
   will be displayed as-is. All others will be averaged
   by the progbar before display.
  interval: Minimum visual progress update interval (in seconds).
 """
 
 def __init__(self, target, width=30, verbose=1, interval=0.05,
     stateful_metrics=None):
  self.target = target
  self.width = width
  self.verbose = verbose
  self.interval = interval
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
  self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
         sys.stdout.isatty()) or
         'ipykernel' in sys.modules)
  self._total_width = 0
  self._seen_so_far = 0
  self._values = collections.OrderedDict()
  self._start = time.time()
  self._last_update = 0
 
 def update(self, current, values=None):
  """Updates the progress bar.
  # Arguments
   current: Index of current step.
   values: List of tuples:
    `(name, value_for_last_step)`.
    If `name` is in `stateful_metrics`,
    `value_for_last_step` will be displayed as-is.
    Else, an average of the metric over time will be displayed.
  """
  values = values or []
  for k, v in values:
   if k not in self.stateful_metrics:
    if k not in self._values:
     self._values[k] = [v * (current - self._seen_so_far),
          current - self._seen_so_far]
    else:
     self._values[k][0] += v * (current - self._seen_so_far)
     self._values[k][1] += (current - self._seen_so_far)
   else:
    # Stateful metrics output a numeric value. This representation
    # means "take an average from a single value" but keeps the
    # numeric formatting.
    self._values[k] = [v, 1]
  self._seen_so_far = current
 
  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
   if (now - self._last_update < self.interval and
     self.target is not None and current < self.target):
    return
 
   prev_total_width = self._total_width
   if self._dynamic_display:
    sys.stdout.write('\b' * prev_total_width)
    sys.stdout.write('\r')
   else:
    sys.stdout.write('\n')
 
   if self.target is not None:
    numdigits = int(np.floor(np.log10(self.target))) + 1
    barstr = '%%%dd/%d [' % (numdigits, self.target)
    bar = barstr % current
    prog = float(current) / self.target
    prog_width = int(self.width * prog)
    if prog_width > 0:
     bar += ('=' * (prog_width - 1))
     if current < self.target:
      bar += '>'
     else:
      bar += '='
    bar += ('.' * (self.width - prog_width))
    bar += ']'
   else:
    bar = '%7d/Unknown' % current
 
   self._total_width = len(bar)
   sys.stdout.write(bar)
 
   if current:
    time_per_unit = (now - self._start) / current
   else:
    time_per_unit = 0
   if self.target is not None and current < self.target:
    eta = time_per_unit * (self.target - current)
    if eta > 3600:
     eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
    elif eta > 60:
     eta_format = '%d:%02d' % (eta // 60, eta % 60)
    else:
     eta_format = '%ds' % eta
 
    info = ' - ETA: %s' % eta_format
   else:
    if time_per_unit >= 1:
     info += ' %.0fs/step' % time_per_unit
    elif time_per_unit >= 1e-3:
     info += ' %.0fms/step' % (time_per_unit * 1e3)
    else:
     info += ' %.0fus/step' % (time_per_unit * 1e6)
 
   for k in self._values:
    info += ' - %s:' % k
    if isinstance(self._values[k], list):
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if abs(avg) > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    else:
     info += ' %s' % self._values[k]
 
   self._total_width += len(info)
   if prev_total_width > self._total_width:
    info += (' ' * (prev_total_width - self._total_width))
 
   if self.target is not None and current >= self.target:
    info += '\n'
 
   sys.stdout.write(info)
   sys.stdout.flush()
 
  elif self.verbose == 2:
   if self.target is None or current >= self.target:
    for k in self._values:
     info += ' - %s:' % k
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if avg > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    info += '\n'
 
    sys.stdout.write(info)
    sys.stdout.flush()
 
  self._last_update = now
 
 def add(self, n, values=None):
  self.update(self._seen_so_far + n, values)

重点是上述代码中的update(self, current, values=None)函数,在该函数内设置断点,即可调入该函数。下面重点分析上述代码中的几个输出条目:

1. sys.stdout.write('\n') #换行

2. sys.stdout.write('bar') #输出 [..................],其中bar= [..................];

3. sys.stdout.write(info) #输出loss格式,其中info='- ETA:...';

4. sys.stdout.flush() #刷新缓存,立即得到输出。

通过对Mask R-CNN代码的调试分析可知,图1中的红框中的训练过程中的Loss格式化输出是由built-in模块实现的。若想得到类似的格式化输出,关键在self.keras_model.fit_generator函数中传入callbacks参数和callbacks中内容的定义。

以上这篇基于Keras的格式化输出Loss实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程池的实现实例
Nov 18 Python
小结Python用fork来创建子进程注意事项
Jul 03 Python
利用Python如何生成hash值示例详解
Dec 20 Python
Flask 让jsonify返回的json串支持中文显示的方法
Mar 26 Python
windows下搭建python scrapy爬虫框架步骤
Dec 23 Python
Python一个简单的通信程序(客户端 服务器)
Mar 06 Python
python pytest进阶之fixture详解
Jun 27 Python
python批量图片处理简单示例
Aug 06 Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 Python
python字符串,元组,列表,字典互转代码实例详解
Feb 14 Python
Python爬取新型冠状病毒“谣言”新闻进行数据分析
Feb 16 Python
TensorFlow使用Graph的基本操作的实现
Apr 22 Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
You might like
ThinkPHP视图查询详解
2014/06/30 PHP
php中__toString()方法用法示例
2016/12/07 PHP
php创建类并调用的实例方法
2019/09/25 PHP
Eval and new funciton not the same thing
2012/12/27 Javascript
JS链式调用的实现方法
2013/03/07 Javascript
js数组去重的常用方法总结
2014/01/24 Javascript
js字符串日期yyyy-MM-dd转化为date示例代码
2014/03/06 Javascript
IE中JS跳转丢失referrer问题的2个解决方法
2014/07/18 Javascript
分享一个自己写的简单的javascript分页组件
2015/02/15 Javascript
zepto中使用swipe.js制作轮播图附swipeUp,swipeDown不起效果问题
2015/08/27 Javascript
jquery读写cookie操作实例分析
2015/12/24 Javascript
jQuery实现的简单悬浮层功能完整实例
2017/01/23 Javascript
jQuery 利用ztree实现树形表格的实例代码
2017/09/27 jQuery
JavaScript 中的12种循环遍历方法【总结】
2018/05/31 Javascript
JS遍历JSON数组及获取JSON数组长度操作示例【测试可用】
2018/12/12 Javascript
JavaScript实现小球沿正弦曲线运动
2020/09/07 Javascript
js实现类似iphone的网页滑屏解锁功能示例【附源码下载】
2019/06/10 Javascript
微信小程序 scroll-view 实现锚点跳转功能
2019/12/12 Javascript
微信小程序 SOTER 生物认证DEMO 指纹识别功能
2019/12/13 Javascript
jQuery HTML获取内容和属性操作实例分析
2020/05/20 jQuery
用Python进行行为驱动开发的入门教程
2015/04/23 Python
Python中scatter函数参数及用法详解
2017/11/08 Python
Python删除n行后的其他行方法
2019/01/28 Python
用Python解数独的方法示例
2019/10/24 Python
浅析python表达式4+0.5值的数据类型
2020/02/26 Python
python 使用cx-freeze打包程序的实现
2020/03/14 Python
selenium学习教程之定位以及切换frame(iframe)
2021/01/04 Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
2021/03/03 Python
IE浏览器单独写CSS样式的几种方法
2014/10/14 HTML / CSS
Too Faced官网:美国知名彩妆品牌
2017/03/07 全球购物
美国职棒大联盟官方网上商店:MLBShop.com
2017/11/12 全球购物
Brasty波兰:香水、化妆品、手表网上商店
2019/04/15 全球购物
二手书店创业计划书
2014/01/16 职场文书
开学典礼演讲稿
2014/05/23 职场文书
ubuntu安装jupyter并设置远程访问的实现
2022/03/31 Python
nginx location 带斜杠【 / 】与不带的区别
2022/04/13 Servers