解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中字典(dict)和列表(list)的排序方法实例
Jun 16 Python
跟老齐学Python之大话题小函数(2)
Oct 10 Python
使用Python读写及压缩和解压缩文件的示例
Jul 08 Python
python机器学习理论与实战(四)逻辑回归
Jan 19 Python
关于python之字典的嵌套,递归调用方法
Jan 21 Python
通过python爬虫赚钱的方法
Jan 29 Python
Python实现图片转字符画的代码实例
Feb 22 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
Django基础知识 URL路由系统详解
Jul 18 Python
简单了解python中的f.b.u.r函数
Nov 02 Python
Django如何使用jwt获取用户信息
Apr 21 Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
一个原生的用户等级的进度条
2010/07/03 Javascript
nodejs分页类代码分享
2014/06/17 NodeJs
ajax的分页查询示例(不刷新页面)
2017/01/11 Javascript
jquery获取select,option所有的value和text的实例
2017/03/06 Javascript
js弹出窗口简单实现代码
2017/03/22 Javascript
JavaScript通过filereader接口读取文件
2017/05/10 Javascript
Vue.js做select下拉列表的实例(ul-li标签仿select标签)
2018/03/02 Javascript
axios发送post请求,提交图片类型表单数据方法
2018/03/16 Javascript
Layui给switch添加响应事件的例子
2019/09/03 Javascript
微信小程序实现蓝牙打印
2019/09/23 Javascript
解决父组件将子组件作为弹窗调用只执行一次created的问题
2020/07/24 Javascript
[08:44]DOTA2发布会群星聚首 我们都是刀塔人
2014/03/21 DOTA
[01:15:44]首部DOTA2纪录片今日23时全网上映
2014/03/19 DOTA
[54:08]LGD女子刀塔学院 DOTA2炼金术士教学
2014/01/09 DOTA
Python和Ruby中each循环引用变量问题(一个隐秘BUG?)
2014/06/04 Python
python爬虫之百度API调用方法
2017/06/11 Python
Django Admin 实现外键过滤的方法
2017/09/29 Python
Python排序搜索基本算法之堆排序实例详解
2017/12/08 Python
利用Hyperic调用Python实现进程守护
2018/01/02 Python
一篇文章快速了解Python的GIL
2018/01/12 Python
python 随机生成10位数密码的实现代码
2019/06/27 Python
python实现比对美团接口返回数据和本地mongo数据是否一致示例
2019/08/09 Python
解决Atom安装Hydrogen无法运行python3的问题
2019/08/28 Python
智能旅行箱:Horizn Studios
2018/04/30 全球购物
日本著名化妆品零售网站:Cosme Land
2019/03/01 全球购物
外语专业毕业生自我评价分享
2013/10/05 职场文书
教师专业自荐信
2014/05/31 职场文书
中学生社会实践活动总结
2014/07/03 职场文书
春节超市活动方案
2014/08/14 职场文书
2014年控辍保学工作总结
2014/12/08 职场文书
2014年居委会工作总结
2014/12/09 职场文书
上课迟到检讨书
2015/05/06 职场文书
钢琴师观后感
2015/06/12 职场文书
母亲去世追悼词
2015/06/23 职场文书
工作违纪的检讨书范文
2019/07/09 职场文书
用python基于appium模块开发一个自动收取能量的小助手
2021/09/25 Python