Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中zip()方法应用实例分析
Apr 16 Python
Python之str操作方法(详解)
Jun 19 Python
python爬虫_自动获取seebug的poc实例
Aug 05 Python
使用django-crontab实现定时任务的示例
Feb 26 Python
Python3.6实现连接mysql或mariadb的方法分析
May 18 Python
Python3实现转换Image图片格式
Jun 21 Python
Python函数参数类型及排序原理总结
Dec 19 Python
Python生成词云的实现代码
Jan 14 Python
Python内建序列通用操作6种实现方法
Mar 26 Python
Python自定义聚合函数merge与transform区别详解
May 26 Python
手把手教你配置JupyterLab 环境的实现
Feb 02 Python
基于PyQT5制作一个桌面摸鱼工具
Feb 15 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
php文件压缩之PHPZip类用法实例
2015/06/18 PHP
详解如何在云服务器上部署Laravel
2017/06/30 PHP
phpStorm2020 注册码
2020/09/17 PHP
JQuery 确定css方框模型(盒模型Box Model)
2010/01/22 Javascript
jquery $.getJSON()跨域请求
2011/12/21 Javascript
基于jQuery通过jQuery.form.js插件实现异步上传
2015/12/13 Javascript
jquery获取select选中值的方法分析
2015/12/22 Javascript
jQuery点击输入框显示验证码图片
2016/05/19 Javascript
自己动手制作基于jQuery的Web页面加载进度条插件
2016/06/03 Javascript
jQuery实现响应鼠标事件的图片透明效果【附demo源码下载】
2016/06/16 Javascript
jQuery简单实现tab选项卡切换效果
2016/06/20 Javascript
Vue.js开发环境搭建
2016/11/10 Javascript
AngularJS的依赖注入实例分析(使用module和injector)
2017/01/19 Javascript
利用node.js实现反向代理的方法详解
2017/07/24 Javascript
Angularjs中的验证input输入框只能输入数字和小数点的写法(推荐)
2017/08/16 Javascript
vue-cli构建项目使用 less的方法
2017/10/04 Javascript
vue实现商城上货组件简易版
2017/11/27 Javascript
JS基于for语句编写的九九乘法表示例
2018/01/04 Javascript
vue mounted组件的使用
2018/06/18 Javascript
利用node 判断打开的是文件 还是 文件夹的实例
2019/06/10 Javascript
js实现for循环跳过undefined值示例
2019/07/02 Javascript
layer弹出框确定前验证:弹出消息框的方法(弹出两个layer)
2019/09/21 Javascript
小程序怎样让wx.navigateBack更好用的方法实现
2019/11/01 Javascript
[04:29]2016国际邀请赛中国区预选赛Ehome战队教练采访
2016/06/27 DOTA
简单的Apache+FastCGI+Django配置指南
2015/07/22 Python
python常用知识梳理(必看篇)
2017/03/23 Python
Python二进制文件读取并转换为浮点数详解
2019/06/25 Python
Python实现弹球小游戏
2020/08/01 Python
奥地利时尚、美容、玩具和家居之家:Kastner & Öhler
2020/04/26 全球购物
应届毕业生求职信范文
2013/12/18 职场文书
客户表扬信范文
2014/01/10 职场文书
汇源肾宝广告词
2014/03/20 职场文书
医院节能减排方案
2014/06/13 职场文书
大学活动总结模板
2014/07/10 职场文书
法定代表人身份证明书(含说明)
2014/10/02 职场文书
angular异步验证器防抖实例详解
2022/03/31 Javascript