解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现更改图片尺寸大小的方法(基于Pillow包)
Sep 19 Python
关于python的bottle框架跨域请求报错问题的处理方法
Mar 19 Python
Python实现将SQLite中的数据直接输出为CVS的方法示例
Jul 13 Python
利用Python实现Windows下的鼠标键盘模拟的实例代码
Jul 13 Python
python TKinter获取文本框内容的方法
Oct 11 Python
python opencv minAreaRect 生成最小外接矩形的方法
Jul 01 Python
django 2.2和mysql使用的常见问题
Jul 18 Python
Django获取该数据的上一条和下一条方法
Aug 12 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
Apr 13 Python
Python中内建模块collections如何使用
May 27 Python
利用Python中的Xpath实现一个在线汇率转换器
Sep 09 Python
python Scrapy爬虫框架的使用
Jan 21 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
php中eval函数的危害与正确禁用方法
2014/06/30 PHP
php实现将任意进制数转换成10进制的方法
2015/04/17 PHP
浅析PHP中call user func()函数及如何使用call user func调用自定义函数
2015/11/05 PHP
php的闭包(Closure)匿名函数初探
2016/02/14 PHP
Laravel5.1框架注册中间件的三种场景详解
2019/07/09 PHP
PHP连接MySQL数据库操作代码实例解析
2020/07/11 PHP
js function定义函数使用心得
2010/04/15 Javascript
JQuery验证jsp页面属性是否为空(实例代码)
2013/11/08 Javascript
jquery实现的一个简单进度条效果实例
2014/05/12 Javascript
JS+CSS实现电子商务网站导航模板效果代码
2015/09/10 Javascript
基于JavaScript实现 网页切出 网站title变化代码
2016/04/03 Javascript
jquery操作ID带有变量的节点实例
2016/12/07 Javascript
VUE2.0中Jsonp的使用方法
2018/05/22 Javascript
ES6新增的数组知识实例小结
2020/05/23 Javascript
Python中input与raw_input 之间的比较
2017/08/20 Python
python使用Apriori算法进行关联性解析
2017/12/21 Python
Python简单生成随机数的方法示例
2018/03/31 Python
python 扩展print打印文件路径和当前时间信息的实例代码
2019/10/11 Python
python 进制转换 int、bin、oct、hex的原理
2021/01/13 Python
Html5剪切板功能的实现代码
2018/06/29 HTML / CSS
浅析border-radius如何兼容IE
2016/04/19 HTML / CSS
乌克兰珠宝大卖场:Zlato.ua
2020/09/27 全球购物
我看到了用指针调用函数的不同语法形式
2014/07/16 面试题
Hashtable 添加内容的方式有哪几种,有什么区别?
2012/04/08 面试题
新入职员工的自我介绍演讲稿
2014/01/02 职场文书
内业资料员岗位职责
2014/01/04 职场文书
经贸韩语专业大学生职业规划
2014/02/14 职场文书
《值日生》教学反思
2014/02/17 职场文书
餐饮总经理岗位职责
2014/03/07 职场文书
《搭石》教学反思
2014/04/07 职场文书
初三学习计划书范文
2014/04/30 职场文书
入股协议书范本
2014/11/01 职场文书
党的群众路线教育实践活动心得体会范文
2014/11/05 职场文书
2019脱贫攻坚工作总结报告范本!
2019/08/06 职场文书
Flask response响应的具体使用
2021/07/15 Python
html网页引入svg图片的4种方式
2022/08/05 HTML / CSS