Keras自定义实现带masking的meanpooling层方式


Posted in Python onJune 16, 2020

Keras确实是一大神器,代码可以写得非常简洁,但是最近在写LSTM和DeepFM的时候,遇到了一个问题:样本的长度不一样。对不定长序列的一种预处理方法是,首先对数据进行padding补0,然后引入keras的Masking层,它能自动对0值进行过滤。

问题在于keras的某些层不支持Masking层处理过的输入数据,例如Flatten、AveragePooling1D等等,而其中meanpooling是我需要的一个运算。例如LSTM对每一个序列的输出长度都等于该序列的长度,那么均值运算就只应该除以序列长度,而不是padding后的最长长度。

例如下面这个 3x4 大小的张量,经过补零padding的。我希望做axis=1的meanpooling,则第一行应该是 (10+20)/2,第二行应该是 (10+20+30)/3,第三行应该是 (10+20+30+40)/4。

Keras自定义实现带masking的meanpooling层方式

Keras如何自定义层

在 Keras2.0 版本中(如果你使用的是旧版本请更新),自定义一个层的方法参考这里。具体地,你只要实现三个方法即可。

build(input_shape) : 这是你定义层参数的地方。这个方法必须设self.built = True,可以通过调用super([Layer], self).build()完成。如果这个层没有需要训练的参数,可以不定义。

call(x) : 这里是编写层的功能逻辑的地方。你只需要关注传入call的第一个参数:输入张量,除非你希望你的层支持masking。

compute_output_shape(input_shape) : 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。

下面是一个简单的例子:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

class MyLayer(Layer):

 def __init__(self, output_dim, **kwargs):
 self.output_dim = output_dim
 super(MyLayer, self).__init__(**kwargs)

 def build(self, input_shape):
 # Create a trainable weight variable for this layer.
 self.kernel = self.add_weight(name='kernel', 
  shape=(input_shape[1], self.output_dim),
  initializer='uniform',
  trainable=True)
 super(MyLayer, self).build(input_shape) # Be sure to call this somewhere!

 def call(self, x):
 return K.dot(x, self.kernel)

 def compute_output_shape(self, input_shape):
 return (input_shape[0], self.output_dim)

Keras自定义层如何允许masking

观察了一些支持masking的层,发现他们对masking的支持体现在两方面。

在 __init__ 方法中设置 supports_masking=True。

实现一个compute_mask方法,用于将mask传到下一层。

部分层会在call中调用传入的mask。

自定义实现带masking的meanpooling

假设输入是3d的。首先,在__init__方法中设置self.supports_masking = True,然后在call中实现相应的计算。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf

class MyMeanPool(Layer):
 def __init__(self, axis, **kwargs):
 self.supports_masking = True
 self.axis = axis
 super(MyMeanPool, self).__init__(**kwargs)

 def compute_mask(self, input, input_mask=None):
 # need not to pass the mask to next layers
 return None

 def call(self, x, mask=None):
 if mask is not None:
 mask = K.repeat(mask, x.shape[-1])
 mask = tf.transpose(mask, [0,2,1])
 mask = K.cast(mask, K.floatx())
 x = x * mask
 return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis)
 else:
 return K.mean(x, axis=self.axis)

 def compute_output_shape(self, input_shape):
 output_shape = []
 for i in range(len(input_shape)):
 if i!=self.axis:
 output_shape.append(input_shape[i])
 return tuple(output_shape)

使用举例:

from keras.layers import Input, Masking
from keras.models import Model
from MyMeanPooling import MyMeanPool

data = [[[10,10],[0, 0 ],[0, 0 ],[0, 0 ]],
 [[10,10],[20,20],[0, 0 ],[0, 0 ]],
 [[10,10],[20,20],[30,30],[0, 0 ]],
 [[10,10],[20,20],[30,30],[40,40]]]

A = Input(shape=[4,2]) # None * 4 * 2
mA = Masking()(A)
out = MyMeanPool(axis=1)(mA)

model = Model(inputs=[A], outputs=[out])

print model.summary()
print model.predict(data)

结果如下,每一行对应一个样本的结果,例如第一个样本只有第一个时刻有值,输出结果是[10. 10. ],是正确的。

[[10. 10.]
 [15. 15.]
 [20. 20.]
 [25. 25.]]

在DeepFM中,每个样本都是由ID构成的,多值field往往会导致样本长度不一的情况,例如interest这样的field,同一个样本可能在该field中有多项取值,毕竟每个人的兴趣点不止一项。

采取padding的方法将每个field的特征补长到最长的长度,则数据尺寸是 [batch_size, max_timestep],经过Embedding为每个样本的每个特征ID配一个latent vector,数据尺寸将变为 [batch_size, max_timestep,latent_dim]。

我们希望每一个field的Embedding之后的尺寸为[batch_size, latent_dim],然后进行concat操作横向拼接,所以这里就可以使用自定义的MeanPool层了。希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之lambda表达式使用方法
Feb 12 Python
python通过加号运算符操作列表的方法
Jul 28 Python
详解Python中使用base64模块来处理base64编码的方法
Jul 01 Python
Python中的FTP通信模块ftplib的用法整理
Jul 08 Python
python基础教程之Filter使用方法
Jan 17 Python
Python3编程实现获取阿里云ECS实例及监控的方法
Aug 18 Python
python3.5绘制随机漫步图
Aug 27 Python
解决pycharm py文件运行后停止按钮变成了灰色的问题
Nov 29 Python
对python3 中方法各种参数和返回值详解
Dec 15 Python
python打印n位数“水仙花数”(实例代码)
Dec 25 Python
Python使用PyQt5/PySide2编写一个极简的音乐播放器功能
Feb 07 Python
利用Python实时获取steam特惠游戏数据
Jun 25 Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
如何在Windows中安装多个python解释器
Jun 16 #Python
You might like
phpBB BBcode处理的漏洞
2006/10/09 PHP
php中的filesystem文件系统函数介绍及使用示例
2014/02/13 PHP
php随机抽奖实例分析
2015/03/04 PHP
PHP利用pdo_odbc实现连接数据库示例【基于ThinkPHP5.1搭建的项目】
2019/05/13 PHP
JS无法捕获滚动条上的mouse up事件的原因猜想
2012/03/21 Javascript
Js注册协议倒计时的小例子
2013/06/24 Javascript
JS 去前后空格大全(IE9亲测)
2013/07/15 Javascript
FireBug 调试JS入门教程 如何调试JS
2013/12/23 Javascript
jQuery获得IE版本不准确webbrowser的解决方法
2014/02/23 Javascript
JQuery中serialize()、serializeArray()和param()方法示例介绍
2014/07/31 Javascript
JavaScript实现同一页面内两个表单互相传值的方法
2015/08/12 Javascript
D3.js中data(), enter() 和 exit()的问题详解
2015/08/17 Javascript
seajs加载jquery时提示$ is not a function该怎么解决
2015/10/23 Javascript
Express实现前端后端通信上传图片之存储数据库(mysql)傻瓜式教程(二)
2015/12/10 Javascript
全面解析标签页的切换方式
2016/08/21 Javascript
扩展jquery easyui tree的搜索树节点方法(推荐)
2016/10/28 Javascript
JScript实现表格的简单操作
2017/08/15 Javascript
JS实现的计数排序与基数排序算法示例
2017/12/04 Javascript
用node开发并发布一个cli工具的方法步骤
2019/01/03 Javascript
Vue 2.0 侦听器 watch属性代码详解
2019/06/19 Javascript
js数组相减简单示例【删除a数组所有与b数组相同元素】
2020/03/04 Javascript
python操作摄像头截图实现远程监控的例子
2014/03/25 Python
Python 字典与字符串的互转实例
2017/01/13 Python
python爬虫的工作原理
2017/03/05 Python
关于 HTML5 的七个传说小结
2012/04/12 HTML / CSS
联想智利官方网站:Lenovo Chile
2020/06/03 全球购物
Chupi官网:在爱尔兰手工制作的订婚、结婚戒指和精美珠宝
2020/09/28 全球购物
护士实习鉴定范文
2013/12/22 职场文书
幼儿教师工作感言
2014/02/14 职场文书
园林设计专业毕业生求职信
2014/03/23 职场文书
2015年幼儿园新年寄语
2014/12/08 职场文书
大学生自荐信范文
2015/03/05 职场文书
2015年信访维稳工作总结
2015/04/07 职场文书
物业工程部主管岗位职责
2015/04/16 职场文书
欠条格式范本
2015/07/03 职场文书
css实现文章分割线样式的多种方法总结
2021/04/21 HTML / CSS