Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例详解Python中yield生成器的用法
Aug 05 Python
Python搜索引擎实现原理和方法
Nov 27 Python
详解Python安装scrapy的正确姿势
Jun 26 Python
python 生成图形验证码的方法示例
Nov 11 Python
Python设计模式之外观模式实例详解
Jan 17 Python
python实现爬山算法的思路详解
Apr 09 Python
对django views中 request, response的常用操作详解
Jul 17 Python
django中上传图片分页三级联动效果的实现代码
Aug 30 Python
Python使用turtle库绘制小猪佩奇(实例代码)
Jan 16 Python
Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境
Apr 08 Python
Python tkinter制作单机五子棋游戏
Sep 14 Python
详解Python Celery和RabbitMQ实战教程
Jan 20 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
PHP 类商品秒杀计时实现代码
2010/05/05 PHP
php实现快速排序法函数代码
2012/08/27 PHP
PHP动态编译出现Cannot find autoconf的解决方法
2014/11/05 PHP
浅谈PHP实现大流量下抢购方案
2017/12/15 PHP
PHP 访问数据库配置通用方法(json)
2018/05/20 PHP
清空上传控件input file的值
2010/07/03 Javascript
JavaScript中判断函数是new还是()调用的区别说明
2011/04/07 Javascript
使用javascipt---实现二分查找法
2013/04/10 Javascript
关于js内存泄露的一个好例子
2013/12/09 Javascript
JavaScript onkeypress事件入门实例(按下或按住一个键盘按键)
2014/10/17 Javascript
jqueryUI里拖拽排序示例分析
2015/02/26 Javascript
javascript基础知识
2016/06/07 Javascript
浅析Ajax语法
2016/12/05 Javascript
详解Angular2 关于*ngFor 嵌套循环
2017/05/22 Javascript
浅谈vue路径优化之resolve
2017/10/13 Javascript
Angular4集成ng2-file-upload的上传组件
2018/03/14 Javascript
JS实现的DOM插入节点操作示例
2018/04/04 Javascript
Vue2.0仿饿了么webapp单页面应用详细步骤
2018/07/08 Javascript
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
JS数组方法reverse()用法实例分析
2020/01/18 Javascript
Vue中避免滥用this去读取data中数据
2021/03/02 Vue.js
[05:56]第十六期——新进3大C之小兔基
2014/06/24 DOTA
[30:55]完美世界DOTA2联赛PWL S2 Magma vs LBZS 第二场 11.18
2020/11/18 DOTA
[50:54]完美世界DOTA2联赛 GXR vs IO 第三场 11.07
2020/11/10 DOTA
深入理解Python中命名空间的查找规则LEGB
2015/08/06 Python
Python彩色化Linux的命令行终端界面的代码实例分享
2016/07/02 Python
Python面向对象程序设计之私有属性及私有方法示例
2019/04/08 Python
使用Python爬取小姐姐图片(beautifulsoup法)
2021/02/11 Python
adidas菲律宾官网:adidas PH
2020/02/07 全球购物
施惠特软件测试面试题以及笔试题
2015/05/13 面试题
六一儿童节活动策划方案
2014/01/27 职场文书
学生犯错保证书
2015/05/09 职场文书
爱国主义电影观后感
2015/06/18 职场文书
详解Python描述符的工作原理
2021/06/11 Python
Canvas如何做个雪花屏版404的实现
2021/09/25 HTML / CSS
Jmerte 分布式压测及分布式压测配置
2022/04/30 Java/Android