解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python __setattr__、 __getattr__、 __delattr__、__call__用法示例
Mar 06 Python
python生成随机mac地址的方法
Mar 16 Python
Python中使用装饰器时需要注意的一些问题
May 11 Python
Python导出数据到Excel可读取的CSV文件的方法
May 12 Python
Python实现PS滤镜特效之扇形变换效果示例
Jan 26 Python
编写多线程Python服务器 最适合基础
Sep 14 Python
详解如何用django实现redirect的几种方法总结
Nov 22 Python
python实现月食效果实例代码
Jun 18 Python
Python子进程subpocess原理及用法解析
Jul 16 Python
Python实现简单猜数字游戏
Feb 03 Python
两行代码解决Jupyter Notebook中文不能显示的问题
Apr 24 Python
只用20行Python代码实现屏幕录制功能
Jun 02 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
关于PHP中Object对象的笔记分享
2011/06/28 PHP
PHP递归算法的详细示例分析
2013/02/19 PHP
利用php获取服务器时间的实现代码
2013/06/07 PHP
使用cookie实现统计访问者登陆次数
2013/06/08 PHP
php导出word格式数据的代码实例
2013/11/25 PHP
php 无限级分类 获取顶级分类ID
2016/03/13 PHP
学习JS面向对象成果 借国庆发布个最新作品与大家交流
2009/10/03 Javascript
JS和jquery获取各种屏幕的宽度和高度的代码
2013/08/02 Javascript
WEB前端设计师常用工具集锦
2014/12/09 Javascript
JavaScript使用replace函数替换字符串的方法
2015/04/06 Javascript
使用AngularJS中的SCE来防止XSS攻击的方法
2015/06/18 Javascript
JavaScript数组方法大全(推荐)
2016/07/05 Javascript
使用BootStrap实现用户登录界面UI
2016/08/10 Javascript
ionic实现滑动的三种方式
2016/08/27 Javascript
如何使用headjs来管理和异步加载js
2016/11/29 Javascript
layer实现关闭弹出层刷新父界面功能详解
2017/11/15 Javascript
js里面的变量范围分享
2020/07/18 Javascript
python获取命令行输入参数列表的实例代码
2018/06/23 Python
利用python实现在微信群刷屏的方法
2019/02/21 Python
Python3实现的简单三级菜单功能示例
2019/03/12 Python
Django 自定义分页器的实现代码
2019/11/24 Python
Python模拟伯努利试验和二项分布代码实例
2020/05/27 Python
如何在keras中添加自己的优化器(如adam等)
2020/06/19 Python
在keras中对单一输入图像进行预测并返回预测结果操作
2020/07/09 Python
python matplotlib库的基本使用
2020/09/23 Python
python里glob模块知识点总结
2021/01/05 Python
html5页面结构_动力节点Java学院整理
2017/07/10 HTML / CSS
浅谈Html5移动端ios/Android兼容性总结
2018/06/01 HTML / CSS
Get The Label中文官网:英国运动时尚购物平台
2017/04/19 全球购物
倩碧英国官网:Clinique英国
2018/08/10 全球购物
自荐信格式技巧有哪些呢
2013/11/19 职场文书
考试不及格检讨书
2014/01/09 职场文书
我心目中的好老师活动方案
2014/08/19 职场文书
2014年作风建设剖析材料
2014/10/23 职场文书
民事和解协议书格式
2014/11/29 职场文书
孝老爱亲事迹材料
2014/12/24 职场文书