解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的filter()函数的用法
Apr 27 Python
python开发环境PyScripter中文乱码问题解决方案
Sep 11 Python
Python制作Windows系统服务
Mar 25 Python
浅谈function(函数)中的动态参数
Apr 30 Python
python实现上传下载文件功能
Nov 19 Python
Python 网页解析HTMLParse的实例详解
Aug 10 Python
如何优雅地改进Django中的模板碎片缓存详解
Jul 04 Python
python函数修饰符@的使用方法解析
Sep 02 Python
Python处理PDF与CDF实例
Feb 26 Python
浅谈Python中的异常和JSON读写数据的实现
Feb 27 Python
python 如何在list中找Topk的数值和索引
May 20 Python
python周期任务调度工具Schedule使用详解
Nov 23 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
phpMyAdmin 安装配置方法和问题解决
2009/06/08 PHP
php缩放图片(根据宽高的等比例缩放)实例介绍
2013/06/09 PHP
如何修改和添加Apache的默认站点目录
2013/07/05 PHP
windows下PHP_intl.dll正确配置方法(apache2.2+php5.3.5)
2014/01/14 PHP
javascript 隔行换色函数代码
2010/10/24 Javascript
JS 实现列表与多选框选择附预览动画
2014/10/29 Javascript
jquery实现页面百叶窗走马灯式翻滚显示效果的方法
2015/03/12 Javascript
javascript实现表单提交后,提交按钮不可用的方法
2015/04/18 Javascript
在JavaScript的正则表达式中使用exec()方法
2015/06/16 Javascript
Jquery实现跨域异步上传文件总结
2017/02/03 Javascript
jQuery实现给input绑定回车事件的方法
2017/02/09 Javascript
jQuery树插件zTree使用方法详解
2017/05/02 jQuery
jQuery遍历节点方法汇总(推荐)
2017/05/13 jQuery
使用vue构建一个上传图片表单
2017/07/04 Javascript
vue-router懒加载速度缓慢问题及解决方法
2018/11/25 Javascript
微信小程序之swiper滑动面板用法示例
2018/12/04 Javascript
JS正则表达式封装与使用操作示例
2019/05/15 Javascript
vue实现防抖的实例代码
2021/01/11 Vue.js
Python获取运行目录与当前脚本目录的方法
2015/06/01 Python
python实现基于SVM手写数字识别功能
2020/05/27 Python
python 定义n个变量方法 (变量声明自动化)
2018/11/10 Python
pycharm中使用anaconda部署python环境的方法步骤
2018/12/19 Python
Python 元组操作总结
2019/09/18 Python
python3 webp转gif格式的实现示例
2019/12/10 Python
python中的subprocess.Popen()使用详解
2019/12/25 Python
使用Jupyter notebooks上传文件夹或大量数据到服务器
2020/04/14 Python
python爬虫如何解决图片验证码
2021/02/14 Python
Public Desire美国/加拿大:全球性的在线鞋类品牌
2018/12/17 全球购物
Nordgreen美国官网:在线购买极简主义斯堪的纳维亚手表
2019/07/24 全球购物
SQL Server数据库笔试题和答案
2016/02/04 面试题
满月酒答谢词
2014/01/14 职场文书
财务科科长岗位职责
2014/03/10 职场文书
项目申请汇报材料
2014/08/16 职场文书
民事诉讼代理授权委托书
2014/10/11 职场文书
诉讼授权委托书
2014/10/15 职场文书
python可视化大屏库big_screen示例详解
2021/11/23 Python