基于Keras的格式化输出Loss实现方式


Posted in Python onJune 17, 2020

在win7 64位,Anaconda安装的Python3.6.1下安装的TensorFlow与Keras,Keras的backend为TensorFlow。在运行Mask R-CNN时,在进行调试时想知道PyCharm (Python IDE)底部窗口输出的Loss格式是在哪里定义的,如下图红框中所示:

基于Keras的格式化输出Loss实现方式

图1 训练过程的Loss格式化输出

在上图红框中,Loss的输出格式是在哪里定义的呢?有一点是明确的,即上图红框中的内容是在训练的时候输出的。那么先来看一下Mask R-CNN的训练过程。Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用 fit 函数。然而由于Mask R-CNN训练数据巨大,不能一次性全部载入,否则太消耗内存。于是采用生成器的方式一次载入一个batch的数据,而且是在用到这个batch的数据才开始载入的,那么它的训练函数如下:

self.keras_model.fit_generator(
   train_generator,
   initial_epoch=self.epoch,
   epochs=epochs,
   steps_per_epoch=self.config.STEPS_PER_EPOCH,
   callbacks=callbacks,
   validation_data=val_generator,
   validation_steps=self.config.VALIDATION_STEPS,
   max_queue_size=100,
   workers=workers,
   use_multiprocessing=False,
  )

这里训练模型的函数相应的为 fit_generator 函数。注意其中的参数callbacks=callbacks,这个参数在输出红框中的内容起到了关键性的作用。下面看一下callbacks的值:

# Callbacks
  callbacks = [
   keras.callbacks.TensorBoard(log_dir=self.log_dir,
          histogram_freq=0, write_graph=True, write_images=False),
   keras.callbacks.ModelCheckpoint(self.checkpoint_path,
           verbose=0, save_weights_only=True),
  ]

在输出红框中的内容所需的数据均保存在self.log_dir下。然后调试进入self.keras_model.fit_generator函数,进入keras,legacy.interfaces的legacy_support(func)函数,如下所示:

def legacy_support(func):
  @six.wraps(func)
  def wrapper(*args, **kwargs):
   if object_type == 'class':
    object_name = args[0].__class__.__name__
   else:
    object_name = func.__name__
   if preprocessor:
    args, kwargs, converted = preprocessor(args, kwargs)
   else:
    converted = []
   if check_positional_args:
    if len(args) > len(allowed_positional_args) + 1:
     raise TypeError('`' + object_name +
         '` can accept only ' +
         str(len(allowed_positional_args)) +
         ' positional arguments ' +
         str(tuple(allowed_positional_args)) +
         ', but you passed the following '
         'positional arguments: ' +
         str(list(args[1:])))
   for key in value_conversions:
    if key in kwargs:
     old_value = kwargs[key]
     if old_value in value_conversions[key]:
      kwargs[key] = value_conversions[key][old_value]
   for old_name, new_name in conversions:
    if old_name in kwargs:
     value = kwargs.pop(old_name)
     if new_name in kwargs:
      raise_duplicate_arg_error(old_name, new_name)
     kwargs[new_name] = value
     converted.append((new_name, old_name))
   if converted:
    signature = '`' + object_name + '('
    for i, value in enumerate(args[1:]):
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(args[1:]) - 1 or kwargs:
      signature += ', '
    for i, (name, value) in enumerate(kwargs.items()):
     signature += name + '='
     if isinstance(value, six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value, np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(kwargs) - 1:
      signature += ', '
    signature += ')`'
    warnings.warn('Update your `' + object_name +
        '` call to the Keras 2 API: ' + signature, stacklevel=2)
   return func(*args, **kwargs)
  wrapper._original_function = func
  return wrapper
 return legacy_support

在上述代码的倒数第4行的return func(*args, **kwargs)处返回func,func为fit_generator函数,现调试进入fit_generator函数,该函数定义在keras.engine.training模块内的fit_generator函数,调试进入函数callbacks.on_epoch_begin(epoch),如下所示:

# Construct epoch logs.
   epoch_logs = {}
   while epoch < epochs:
    for m in self.stateful_metric_functions:
     m.reset_states()
    callbacks.on_epoch_begin(epoch)

调试进入到callbacks.on_epoch_begin(epoch)函数,进入on_epoch_begin函数,如下所示:

def on_epoch_begin(self, epoch, logs=None):
  """Called at the start of an epoch.
  # Arguments
   epoch: integer, index of epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  for callback in self.callbacks:
   callback.on_epoch_begin(epoch, logs)
  self._delta_t_batch = 0.
  self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
  self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

在上述函数on_epoch_begin中调试进入callback.on_epoch_begin(epoch, logs)函数,转到类ProgbarLogger(Callback)中定义的on_epoch_begin函数,如下所示:

class ProgbarLogger(Callback):
 """Callback that prints metrics to stdout.
 # Arguments
  count_mode: One of "steps" or "samples".
   Whether the progress bar should
   count samples seen or steps (batches) seen.
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over an epoch.
   Metrics in this list will be logged as-is.
   All others will be averaged over time (e.g. loss, etc).
 # Raises
  ValueError: In case of invalid `count_mode`.
 """
 
 def __init__(self, count_mode='samples',
     stateful_metrics=None):
  super(ProgbarLogger, self).__init__()
  if count_mode == 'samples':
   self.use_steps = False
  elif count_mode == 'steps':
   self.use_steps = True
  else:
   raise ValueError('Unknown `count_mode`: ' + str(count_mode))
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
 def on_train_begin(self, logs=None):
  self.verbose = self.params['verbose']
  self.epochs = self.params['epochs']
 
 def on_epoch_begin(self, epoch, logs=None):
  if self.verbose:
   print('Epoch %d/%d' % (epoch + 1, self.epochs))
   if self.use_steps:
    target = self.params['steps']
   else:
    target = self.params['samples']
   self.target = target
   self.progbar = Progbar(target=self.target,
         verbose=self.verbose,
         stateful_metrics=self.stateful_metrics)
  self.seen = 0

在上述代码的

print('Epoch %d/%d' % (epoch + 1, self.epochs))

输出

Epoch 1/40(如红框中所示内容的第一行)。

然后返回到keras.engine.training模块内的fit_generator函数,执行到self.train_on_batch函数,如下所示:

outs = self.train_on_batch(x, y,
     sample_weight=sample_weight,
     class_weight=class_weight)
 
     if not isinstance(outs, list):
      outs = [outs]
     for l, o in zip(out_labels, outs):
      batch_logs[l] = o
 
     callbacks.on_batch_end(batch_index, batch_logs)
 
     batch_index += 1
     steps_done += 1

调试进入上述代码中的callbacks.on_batch_end(batch_index, batch_logs)函数,进入到on_batch_end函数后,该函数的定义如下所示:

def on_batch_end(self, batch, logs=None):
  """Called at the end of a batch.
  # Arguments
   batch: integer, index of batch within the current epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  if not hasattr(self, '_t_enter_batch'):
   self._t_enter_batch = time.time()
  self._delta_t_batch = time.time() - self._t_enter_batch
  t_before_callbacks = time.time()
  for callback in self.callbacks:
   callback.on_batch_end(batch, logs)
  self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
  delta_t_median = np.median(self._delta_ts_batch_end)
  if (self._delta_t_batch > 0. and
   (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
   warnings.warn('Method on_batch_end() is slow compared '
       'to the batch update (%f). Check your callbacks.'
       % delta_t_median)

接着继续调试进入上述代码中的callback.on_batch_end(batch, logs)函数,进入到在类中ProgbarLogger(Callback)定义的on_batch_end函数,如下所示:

def on_batch_end(self, batch, logs=None):
  logs = logs or {}
  batch_size = logs.get('size', 0)
  if self.use_steps:
   self.seen += 1
  else:
   self.seen += batch_size
 
  for k in self.params['metrics']:
   if k in logs:
    self.log_values.append((k, logs[k]))
 
  # Skip progbar update for the last batch;
  # will be handled by on_epoch_end.
  if self.verbose and self.seen < self.target:
   self.progbar.update(self.seen, self.log_values)

然后执行到上述代码的最后一行self.progbar.update(self.seen, self.log_values),调试进入update函数,该函数定义在模块keras.utils.generic_utils中的类Progbar(object)定义的函数。类的定义及方法如下所示:

class Progbar(object):
 """Displays a progress bar.
 # Arguments
  target: Total number of steps expected, None if unknown.
  width: Progress bar width on screen.
  verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over time. Metrics in this list
   will be displayed as-is. All others will be averaged
   by the progbar before display.
  interval: Minimum visual progress update interval (in seconds).
 """
 
 def __init__(self, target, width=30, verbose=1, interval=0.05,
     stateful_metrics=None):
  self.target = target
  self.width = width
  self.verbose = verbose
  self.interval = interval
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
  self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
         sys.stdout.isatty()) or
         'ipykernel' in sys.modules)
  self._total_width = 0
  self._seen_so_far = 0
  self._values = collections.OrderedDict()
  self._start = time.time()
  self._last_update = 0
 
 def update(self, current, values=None):
  """Updates the progress bar.
  # Arguments
   current: Index of current step.
   values: List of tuples:
    `(name, value_for_last_step)`.
    If `name` is in `stateful_metrics`,
    `value_for_last_step` will be displayed as-is.
    Else, an average of the metric over time will be displayed.
  """
  values = values or []
  for k, v in values:
   if k not in self.stateful_metrics:
    if k not in self._values:
     self._values[k] = [v * (current - self._seen_so_far),
          current - self._seen_so_far]
    else:
     self._values[k][0] += v * (current - self._seen_so_far)
     self._values[k][1] += (current - self._seen_so_far)
   else:
    # Stateful metrics output a numeric value. This representation
    # means "take an average from a single value" but keeps the
    # numeric formatting.
    self._values[k] = [v, 1]
  self._seen_so_far = current
 
  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
   if (now - self._last_update < self.interval and
     self.target is not None and current < self.target):
    return
 
   prev_total_width = self._total_width
   if self._dynamic_display:
    sys.stdout.write('\b' * prev_total_width)
    sys.stdout.write('\r')
   else:
    sys.stdout.write('\n')
 
   if self.target is not None:
    numdigits = int(np.floor(np.log10(self.target))) + 1
    barstr = '%%%dd/%d [' % (numdigits, self.target)
    bar = barstr % current
    prog = float(current) / self.target
    prog_width = int(self.width * prog)
    if prog_width > 0:
     bar += ('=' * (prog_width - 1))
     if current < self.target:
      bar += '>'
     else:
      bar += '='
    bar += ('.' * (self.width - prog_width))
    bar += ']'
   else:
    bar = '%7d/Unknown' % current
 
   self._total_width = len(bar)
   sys.stdout.write(bar)
 
   if current:
    time_per_unit = (now - self._start) / current
   else:
    time_per_unit = 0
   if self.target is not None and current < self.target:
    eta = time_per_unit * (self.target - current)
    if eta > 3600:
     eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
    elif eta > 60:
     eta_format = '%d:%02d' % (eta // 60, eta % 60)
    else:
     eta_format = '%ds' % eta
 
    info = ' - ETA: %s' % eta_format
   else:
    if time_per_unit >= 1:
     info += ' %.0fs/step' % time_per_unit
    elif time_per_unit >= 1e-3:
     info += ' %.0fms/step' % (time_per_unit * 1e3)
    else:
     info += ' %.0fus/step' % (time_per_unit * 1e6)
 
   for k in self._values:
    info += ' - %s:' % k
    if isinstance(self._values[k], list):
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if abs(avg) > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    else:
     info += ' %s' % self._values[k]
 
   self._total_width += len(info)
   if prev_total_width > self._total_width:
    info += (' ' * (prev_total_width - self._total_width))
 
   if self.target is not None and current >= self.target:
    info += '\n'
 
   sys.stdout.write(info)
   sys.stdout.flush()
 
  elif self.verbose == 2:
   if self.target is None or current >= self.target:
    for k in self._values:
     info += ' - %s:' % k
     avg = np.mean(
      self._values[k][0] / max(1, self._values[k][1]))
     if avg > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    info += '\n'
 
    sys.stdout.write(info)
    sys.stdout.flush()
 
  self._last_update = now
 
 def add(self, n, values=None):
  self.update(self._seen_so_far + n, values)

重点是上述代码中的update(self, current, values=None)函数,在该函数内设置断点,即可调入该函数。下面重点分析上述代码中的几个输出条目:

1. sys.stdout.write('\n') #换行

2. sys.stdout.write('bar') #输出 [..................],其中bar= [..................];

3. sys.stdout.write(info) #输出loss格式,其中info='- ETA:...';

4. sys.stdout.flush() #刷新缓存,立即得到输出。

通过对Mask R-CNN代码的调试分析可知,图1中的红框中的训练过程中的Loss格式化输出是由built-in模块实现的。若想得到类似的格式化输出,关键在self.keras_model.fit_generator函数中传入callbacks参数和callbacks中内容的定义。

以上这篇基于Keras的格式化输出Loss实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本将Bing的每日图片作为桌面的教程
May 04 Python
Python中title()方法的使用简介
May 20 Python
python实现爬取千万淘宝商品的方法
Jun 30 Python
asyncio 的 coroutine对象 与 Future对象使用指南
Sep 11 Python
老生常谈Python基础之字符编码
Jun 14 Python
python之Character string(实例讲解)
Sep 25 Python
Python中装饰器高级用法详解
Dec 25 Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 Python
python自定义时钟类、定时任务类
Feb 22 Python
Anaconda和ipython环境适配的实现
Apr 22 Python
Pycharm常用快捷键总结及配置方法
Nov 14 Python
python内置进制转换函数的操作
Jun 02 Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
You might like
Yii+upload实现AJAX上传图片的方法
2016/07/13 PHP
Laravel接收前端ajax传来的数据的实例代码
2017/07/20 PHP
PHP判断json格式是否正确的实现代码
2017/09/20 PHP
php基于环形链表解决约瑟夫环问题示例
2017/11/07 PHP
自动更新作用
2006/10/08 Javascript
Javascript 实现TreeView CheckBox全选效果
2010/01/11 Javascript
jquery实现文字由下到上循环滚动的实例代码
2013/08/09 Javascript
javascript使用location.search的示例
2013/11/05 Javascript
js完美实现@提到好友特效(兼容各大浏览器)
2015/03/16 Javascript
JavaScript+html5 canvas实现本地截图教程
2020/04/16 Javascript
jQuery实现背景弹性滚动的导航效果
2016/06/01 Javascript
nodejs利用ajax实现网页无刷新上传图片实例代码
2017/06/06 NodeJs
vue的一个分页组件的示例代码
2017/12/25 Javascript
微信小程序实现跑马灯效果完整代码(附效果图)
2018/05/30 Javascript
浅谈在vue中使用mint-ui swipe遇到的问题
2018/09/27 Javascript
javascript 内存模型实例详解
2020/04/18 Javascript
vscode 插件开发 + vue的操作方法
2020/06/05 Javascript
[01:32]dota2拉比克至宝(222)
2018/12/20 DOTA
python爬取网站数据保存使用的方法
2013/11/20 Python
零基础写python爬虫之打包生成exe文件
2014/11/06 Python
python实现查找excel里某一列重复数据并且剔除后打印的方法
2015/05/26 Python
Python实现蒙特卡洛算法小实验过程详解
2019/07/12 Python
python pyqtgraph 保存图片到本地的实例
2020/03/14 Python
Python面向对象程序设计之继承、多态原理与用法详解
2020/03/23 Python
pandas 强制类型转换 df.astype实例
2020/04/09 Python
HTML5文档结构标签
2017/04/21 HTML / CSS
HTML5中在title标题标签里设置小图标的方法
2020/06/23 HTML / CSS
Avène雅漾美国官方网站:敏感肌肤护理专家
2016/10/24 全球购物
护士自荐信
2013/10/25 职场文书
晚会邀请函范文
2014/01/24 职场文书
怀念母亲教学反思
2014/04/28 职场文书
数字化校园建设方案
2014/05/03 职场文书
保证书格式
2015/01/16 职场文书
文明和谐家庭事迹材料(2016精选版)
2016/02/29 职场文书
Nginx 匹配方式
2022/05/15 Servers
MySQL8.0 Undo Tablespace管理详解
2022/06/16 MySQL