Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python BeautifulSoup中文乱码问题的2种解决方法
Apr 22 Python
python中pygame模块用法实例
Oct 09 Python
python获得两个数组交集、并集、差集的方法
Mar 27 Python
python函数式编程学习之yield表达式形式详解
Mar 25 Python
python K近邻算法的kd树实现
Sep 06 Python
python3.x实现base64加密和解密
Mar 28 Python
使用pandas读取文件的实现
Jul 31 Python
python如何实现复制目录到指定目录
Feb 13 Python
基于Python把网站域名解析成ip地址
May 25 Python
使用Dajngo 通过代码添加xadmin用户和权限(组)
Jul 03 Python
Python如何将字符串转换为日期
Jul 31 Python
Python 的 sum() Pythonic 的求和方法详细
Oct 16 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
PHP读取ACCESS数据到MYSQL的代码
2011/05/11 PHP
sphinx增量索引的一个问题
2011/06/14 PHP
yii2中的rules 自定义验证规则详解
2016/04/19 PHP
Joomla调用系统自带编辑器的实现方法
2016/05/05 PHP
潜说js对象和数组
2011/05/25 Javascript
javascript删除数组元素并且数组长度减小的简单实例
2014/02/14 Javascript
Bootstrap基础学习
2015/06/16 Javascript
使用requestAnimationFrame实现js动画性能好
2015/08/06 Javascript
jQuery实现仿百度帖吧头部固定导航效果
2015/08/07 Javascript
js简单实现图片延迟加载的方法
2016/07/19 Javascript
Js动态设置rem来实现移动端字体的自适应代码
2016/10/14 Javascript
微信公众号 摇一摇周边功能开发
2016/12/08 Javascript
jQuery实现select模糊查询(反射机制)
2017/01/14 Javascript
详解Vue 事件驱动和依赖追踪
2017/04/22 Javascript
Vue.js项目部署到服务器的详细步骤
2017/07/17 Javascript
在vue项目中集成graphql(vue-ApolloClient)
2018/09/08 Javascript
angular4自定义表单控件[(ngModel)]的实现
2018/11/23 Javascript
使用Vue 自定义文件选择器组件的实例代码
2020/03/04 Javascript
聊聊vue 中的v-on参数问题
2021/01/29 Vue.js
python基础教程之基本数据类型和变量声明介绍
2014/08/29 Python
21行Python代码实现拼写检查器
2016/01/25 Python
Python 数值区间处理_对interval 库的快速入门详解
2018/11/16 Python
Python Pexpect库的简单使用方法
2019/01/29 Python
使用Python-OpenCV向图片添加噪声的实现(高斯噪声、椒盐噪声)
2019/05/28 Python
Python logging日志库空间不足问题解决
2020/09/14 Python
如何基于Python pygame实现动画跑马灯
2020/11/18 Python
css3实现超炫风车特效
2014/11/12 HTML / CSS
HTML5之SVG 2D入门9—蒙板及mask元素介绍与应用
2013/01/30 HTML / CSS
英语硕士生求职简历的自我评价
2013/10/15 职场文书
求职信范文英文版
2014/01/05 职场文书
物流专业大学的自我评价
2014/01/11 职场文书
行政执法队伍作风整顿剖析材料
2014/10/11 职场文书
2014年乡镇民政工作总结
2014/12/02 职场文书
教师工作表现评语
2014/12/31 职场文书
门卫岗位职责
2015/02/09 职场文书
品德与社会教学反思
2016/02/24 职场文书