解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python高效编程技巧
Jan 07 Python
Python数字图像处理之霍夫线变换实现详解
Jan 12 Python
Python温度转换实例分析
Jan 17 Python
Python简单计算文件MD5值的方法示例
Apr 11 Python
python 将数据保存为excel的xls格式(实例讲解)
May 03 Python
python基于物品协同过滤算法实现代码
May 31 Python
pandas 空的dataframe 插入列名的示例
Oct 30 Python
对python当中不在本路径的py文件的引用详解
Dec 15 Python
python利用插值法对折线进行平滑曲线处理
Dec 25 Python
Python跑循环时内存泄露的解决方法
Jan 13 Python
利用python绘制数据曲线图的实现
Apr 09 Python
python中os.path.join()函数实例用法
May 26 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
php session和cookie使用说明
2010/04/07 PHP
php在程序中将网页生成word文档并提供下载的代码
2012/10/09 PHP
phpmailer绑定邮箱的实现方法
2016/12/01 PHP
PHP生成(支持多模板)二维码海报代码
2018/04/30 PHP
javascript间隔刷新的简单实例
2013/11/14 Javascript
jQuery实现仿Alipay支付宝首页全屏焦点图切换特效
2015/05/04 Javascript
浅谈Javascript数组(推荐)
2016/05/17 Javascript
Angular2从搭建环境到开发步骤详解
2016/10/17 Javascript
js中获取键盘按下键值event.keyCode、event.charCode和event.which的兼容性详解
2017/03/15 Javascript
jQuery轻松实现无缝轮播效果
2017/03/22 jQuery
原生js 封装get ,post, delete 请求的实例
2017/08/11 Javascript
vue系列之动态路由详解【原创】
2017/09/10 Javascript
Vue响应式原理深入解析及注意事项
2017/12/11 Javascript
JavaScript实现邮箱后缀提示功能的示例代码
2018/12/13 Javascript
JS开发 富文本编辑器TinyMCE详解
2019/07/19 Javascript
微信JS-SDK实现微信会员卡功能(给用户微信卡包里发送会员卡)
2019/07/25 Javascript
swiperjs实现导航与tab页的联动
2020/12/13 Javascript
python基础入门学习笔记(Python环境搭建)
2016/01/13 Python
Python3调用百度AI识别图片中的文字功能示例【测试可用】
2019/03/13 Python
django 链接多个数据库 并使用原生sql实现
2020/03/28 Python
HTML5播放实现rtmp流直播
2020/06/16 HTML / CSS
捷克电器和DJ设备网上商店:Electronic-star
2017/07/18 全球购物
波兰快递服务:Globkurier.pl
2019/11/08 全球购物
Everlast官网:拳击、综合格斗和健身相关的体育用品
2020/08/03 全球购物
JSF的标签库有哪些
2012/04/27 面试题
学雷锋树新风演讲稿
2014/05/10 职场文书
专项法律服务方案
2014/06/11 职场文书
明星邀请函
2015/02/02 职场文书
银行催款通知书
2015/04/17 职场文书
nginx里的rewrite跳转的实现
2021/03/31 Servers
go语言求任意类型切片的长度操作
2021/04/26 Golang
mysql数据库入门第一步之创建表
2021/05/14 MySQL
分析设计模式之模板方法Java实现
2021/06/23 Java/Android
Django路由层如何获取正确的url
2021/07/15 Python
Python OpenCV之常用滤波器使用详解
2022/04/07 Python
React自定义hook的方法
2022/06/25 Javascript