Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中subprocess模块用法实例详解
May 20 Python
Python基于select实现的socket服务器
Apr 13 Python
python使用logging模块发送邮件代码示例
Jan 18 Python
Python3用tkinter和PIL实现看图工具
Jun 21 Python
python实现公司年会抽奖程序
Jan 22 Python
对Python生成汉字字库文字,以及转换为文字图片的实例详解
Jan 29 Python
Python3安装Pillow与PIL的方法
Apr 03 Python
python爬虫selenium和phantomJs使用方法解析
Aug 08 Python
Python装饰器如何实现修复过程解析
Sep 05 Python
python 还原梯度下降算法实现一维线性回归
Oct 22 Python
python 使用paramiko模块进行封装,远程操作linux主机的示例代码
Dec 03 Python
python re模块常见用法例举
Mar 01 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
极典R601SW收音机
2021/03/02 无线电
PHP怎样调用MSSQL的存储过程
2006/10/09 PHP
配置Apache2.2+PHP5+CakePHP1.2+MySQL5运行环境
2009/04/25 PHP
PHP的preg_match匹配字符串长度问题解决方法
2014/05/03 PHP
PHP中ID设置自增后不连续的原因分析及解决办法
2016/08/21 PHP
Yii2语言国际化的配置教程
2018/08/19 PHP
PHP实现无限极分类的两种方式示例【递归和引用方式】
2019/03/25 PHP
php curl操作API接口类完整示例
2019/05/21 PHP
通过隐藏option实现select的联动效果
2009/11/10 Javascript
Javascript学习笔记5 类和对象
2010/01/11 Javascript
使用jquery动态加载javascript以减少服务器压力
2012/10/29 Javascript
javascript实现C语言经典程序题
2015/11/29 Javascript
AngularJs 指令详解及示例代码
2016/09/01 Javascript
关于js函数解释(包括内嵌,对象等)
2016/11/20 Javascript
浅谈Vue响应式(数组变异方法)
2018/05/07 Javascript
jQuery+css last-child实现选择最后一个子元素操作示例
2018/12/10 jQuery
js实现json数组分组合并操作示例
2019/02/12 Javascript
了解javascript中变量及函数的提升
2019/05/27 Javascript
vue使用websocket的方法实例分析
2019/06/22 Javascript
原生Vue 实现右键菜单组件功能
2019/12/16 Javascript
prettier自动格式化去换行的实现代码
2020/08/25 Javascript
Vue 实现监听窗口关闭事件,并在窗口关闭前发送请求
2020/09/01 Javascript
[03:23]我的刀塔你不可能这么可爱 第一期金萌萌的故事
2014/06/20 DOTA
[03:15]DOTA2-DPC中国联赛1月22日Recap集锦
2021/03/11 DOTA
python编写Logistic逻辑回归
2020/12/30 Python
Python学习笔记之文件的读写操作实例分析
2019/08/07 Python
Python中 CSV格式清洗与转换的实例代码
2019/08/29 Python
python列表推导式入门学习解析
2019/12/02 Python
Pytorch对Himmelblau函数的优化详解
2020/02/29 Python
css3media响应式布局实例
2016/07/08 HTML / CSS
HTML5 CSS3实现一个精美VCD包装盒个性幻灯片案例
2014/06/16 HTML / CSS
小天鹅官方商城:LittleSwan
2017/06/16 全球购物
实习报告范文之电话客服岗位
2019/07/26 职场文书
python基于opencv批量生成验证码的示例
2021/04/28 Python
详解Python函数print用法
2021/06/18 Python
create-react-app开发常用配置教程
2022/06/25 Javascript