解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python启动办公软件进程(word、excel、ppt、以及wps的et、wps、wpp)
Apr 09 Python
Python字符串和文件操作常用函数分析
Apr 08 Python
在Python中操作时间之mktime()方法的使用教程
May 22 Python
python验证码识别的实例详解
Sep 09 Python
使用python和Django完成博客数据库的迁移方法
Jan 05 Python
Sanic框架蓝图用法实例分析
Jul 17 Python
对python中url参数编码与解码的实例详解
Jul 25 Python
用Python徒手撸一个股票回测框架搭建【推荐】
Aug 05 Python
Django为窗体加上防机器人的验证码功能过程解析
Aug 14 Python
从numpy数组中取出满足条件的元素示例
Nov 26 Python
Pytorch evaluation每次运行结果不同的解决
Jan 02 Python
python urllib库的使用详解
Apr 13 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
如何对PHP程序中的常见漏洞进行攻击(上)
2006/10/09 PHP
PHP树的深度编历生成迷宫及A*自动寻路算法实例分析
2015/03/10 PHP
PHP 5.6.11 访问SQL Server2008R2的几种情况详解
2016/08/08 PHP
PHP生成唯一ID之SnowFlake算法
2016/12/17 PHP
Javascript this关键字使用分析
2008/10/21 Javascript
Prototype Date对象 学习
2009/07/12 Javascript
Javascript 键盘keyCode键码值表
2009/12/24 Javascript
js substr支持中文截取函数代码(中文是双字节)
2013/04/17 Javascript
深入理解javascript的执行顺序
2014/04/04 Javascript
判断及设置浏览器全屏模式
2014/04/20 Javascript
jQuery实现异步获取json数据的2种方式
2014/08/29 Javascript
Javascript函数的参数
2015/07/16 Javascript
Easyui 之 Treegrid 笔记
2016/04/29 Javascript
jQuery实现简洁的轮播图效果实例
2016/09/07 Javascript
js手动播放图片实现图片轮播效果
2016/09/17 Javascript
Bootstrap文件上传组件之bootstrap fileinput
2016/11/25 Javascript
NodeJs使用Mysql模块实现事务处理实例
2017/05/31 NodeJs
node之本地服务器图片上传的方法示例
2019/03/26 Javascript
使用Vue 实现滑动验证码功能
2019/06/27 Javascript
改进 JavaScript 和 Rust 的互操作性并深入认识 wasm-bindgen 组件
2019/07/13 Javascript
原生js实现无缝轮播图效果
2021/01/28 Javascript
手写Vue2.0 数据劫持的示例
2021/03/04 Vue.js
[42:24]完美世界DOTA2联赛PWL S2 LBZS vs FTD.C 第三场 11.27
2020/12/01 DOTA
35个Python编程小技巧
2014/04/01 Python
pandas实现to_sql将DataFrame保存到数据库中
2019/07/03 Python
python实现七段数码管和倒计时效果
2019/11/23 Python
python将四元数变换为旋转矩阵的实例
2019/12/04 Python
基于MSELoss()与CrossEntropyLoss()的区别详解
2020/01/02 Python
Python timeit模块的使用实践
2020/01/13 Python
Python能做什么
2020/06/02 Python
python Matplotlib数据可视化(2):详解三大容器对象与常用设置
2020/09/30 Python
基于OpenCV的网络实时视频流传输的实现
2020/11/15 Python
意大利在线购买隐形眼镜网站:VisionDirect.it
2019/03/18 全球购物
自我评价优秀范文分享
2013/11/30 职场文书
2015年教师党员自我评价材料
2015/03/04 职场文书
2019交通安全宣传标语集锦!
2019/06/28 职场文书