解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中requests模块的使用方法
Apr 08 Python
深入浅析Python字符编码
Nov 12 Python
django js实现部分页面刷新的示例代码
May 28 Python
python脚本监控Tomcat服务器的方法
Jul 06 Python
详解基于django实现的webssh简单例子
Jul 17 Python
python使用tornado实现登录和登出
Jul 28 Python
python处理DICOM并计算三维模型体积
Feb 26 Python
python实现单链表的方法示例
Sep 03 Python
python GUI库图形界面开发之PyQt5树形结构控件QTreeWidget详细使用方法与实例
Mar 02 Python
python由已知数组快速生成新数组的方法
Apr 08 Python
python如何提升爬虫效率
Sep 27 Python
python利用proxybroker构建爬虫免费IP代理池的实现
Feb 21 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
一个基于phpQuery的php通用采集类分享
2014/04/09 PHP
php版微信发红包接口用法示例
2016/09/23 PHP
CI框架实现框架前后端分离的方法详解
2016/12/30 PHP
PHP的HTTP客户端Guzzle简单使用方法分析
2019/10/30 PHP
javascript继承之为什么要继承
2012/11/10 Javascript
兼容IE和Firefox火狐的上下、左右循环无间断滚动JS代码
2013/04/19 Javascript
js 获取时间间隔实现代码
2014/05/12 Javascript
JavaScript函数详解
2015/02/27 Javascript
深入解析JavaScript中函数的Currying柯里化
2016/03/19 Javascript
JS与jQuery实现隔行变色的方法
2016/09/09 Javascript
深入理解JavaScript中的并行处理
2016/09/22 Javascript
JS实现探测网站链接的方法【测试可用】
2016/11/08 Javascript
Javascript实现倒计时时差效果
2017/05/18 Javascript
javascript 日期相减-在线教程(附代码)
2017/08/17 Javascript
vue 登录滑动验证实现代码
2018/08/24 Javascript
微信小程序之onLaunch与onload异步问题详解
2019/03/28 Javascript
vue-cli和v-charts实现可视化图表过程解析
2019/10/08 Javascript
node.js实现简单的压缩/解压缩功能示例
2019/11/05 Javascript
基于Vue2实现移动端图片上传、压缩、拖拽排序、拖拽删除功能
2021/01/05 Vue.js
python实现linux服务器批量修改密码并生成execl
2014/04/22 Python
Python 使用requests模块发送GET和POST请求的实现代码
2016/09/21 Python
python gdal安装与简单使用
2019/08/01 Python
Python将视频或者动态图gif逐帧保存为图片的方法
2019/09/10 Python
python程序如何进行保存
2020/07/03 Python
cookies应对python反爬虫知识点详解
2020/11/25 Python
你正在寻找的CSS3 动画技术
2011/07/27 HTML / CSS
利用纯CSS3实现动态的自行车特效源码
2017/01/20 HTML / CSS
小程序canvas中文字设置居中锚点
2019/04/16 HTML / CSS
员工试用期考核自我鉴定
2014/04/13 职场文书
疾病防治方案
2014/05/31 职场文书
三方股东合作协议书范本
2014/09/28 职场文书
关于拾金不昧的感谢信
2015/01/21 职场文书
初二数学教学反思
2016/02/17 职场文书
canvas多重阴影发光效果实现
2021/04/20 Javascript
关于CSS浮动与取消浮动的问题
2021/06/28 HTML / CSS
Apache SkyWalking 监控 MySQL Server 实战解析
2022/09/23 Servers