解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成手机号、数字的方法详解
Jul 21 Python
Tensorflow中的placeholder和feed_dict的使用
Jul 09 Python
Python3.4学习笔记之 idle 清屏扩展插件用法分析
Mar 01 Python
Python3.4学习笔记之常用操作符,条件分支和循环用法示例
Mar 01 Python
分析运行中的 Python 进程详细解析
Jun 22 Python
Python time库基本使用方法分析
Dec 13 Python
用python打开摄像头并把图像传回qq邮箱(Pyinstaller打包)
May 17 Python
关于keras中keras.layers.merge的用法说明
May 23 Python
Python3创建Django项目的几种方法(3种)
Jun 03 Python
Python如何使用27行代码绘制星星图
Jul 20 Python
Python如何读取、写入JSON数据
Jul 28 Python
Python如何telnet到网络设备
Feb 18 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
php smarty 二级分类代码和模版循环例子
2011/06/01 PHP
PHP5.3新特性小结
2016/02/14 PHP
轻松创建nodejs服务器(7):阻塞操作的实现
2014/12/18 NodeJs
javascript实现多栏闭合展开式广告位菜单效果实例
2015/08/05 Javascript
jQuery实现大转盘抽奖活动仿QQ音乐代码分享
2015/08/21 Javascript
jQuery实现响应鼠标滚动的动感菜单效果
2015/09/21 Javascript
jquery实现网页的页面平滑滚动效果代码
2015/11/02 Javascript
基于JavaScript实现一定时间后去执行一个函数
2015/12/14 Javascript
结合代码图文讲解JavaScript中的作用域与作用域链
2016/07/05 Javascript
input获取焦点时底部菜单被顶上来问题的解决办法
2017/01/24 Javascript
BootStrap 导航条实例代码
2017/05/18 Javascript
JavaScript输出所选择起始与结束日期的方法
2017/07/12 Javascript
JavaScript函数中的this四种绑定形式
2017/08/15 Javascript
JavaScript实现的前端AES加密解密功能【基于CryptoJS】
2018/08/28 Javascript
浅谈vue同一页面中拥有两个表单时,的验证问题
2018/09/18 Javascript
详解Vue项目在其他电脑npm run dev运行报错的解决方法
2018/10/29 Javascript
js实现图片跟随鼠标移动效果
2019/10/16 Javascript
[03:14]DOTA2斧王 英雄基础教程
2013/11/26 DOTA
Python读写unicode文件的方法
2015/07/10 Python
Python简单定义与使用字典dict的方法示例
2017/07/25 Python
Python初学时购物车程序练习实例(推荐)
2017/08/08 Python
python绘制封闭多边形教程
2020/02/18 Python
python生成并处理uuid的实现方式
2020/03/03 Python
python如何爬取动态网站
2020/09/09 Python
HTML5+Canvas+CSS3实现齐天大圣孙悟空腾云驾雾效果
2016/04/26 HTML / CSS
Viking Direct爱尔兰:办公用品和家具
2019/11/21 全球购物
JAVA中的关键字有什么特点
2014/03/07 面试题
上班迟到检讨书
2014/01/10 职场文书
金融系应届毕业生求职信
2014/05/26 职场文书
软件售后服务方案
2014/05/29 职场文书
2014年幼儿园德育工作总结
2014/12/17 职场文书
项目负责人岗位职责
2015/02/15 职场文书
大学生十八大感想
2015/08/11 职场文书
2016年师德师风学习心得体会
2016/01/12 职场文书
2016年学校“3.12”植树节活动总结
2016/03/16 职场文书
如何在向量化NumPy数组上进行移动窗口
2021/05/18 Python