Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现简单的TCP代理服务器
Oct 08 Python
浅谈Python的异常处理
Jun 19 Python
Python判断文件或文件夹是否存在的三种方法
Jul 27 Python
Python中的is和==比较两个对象的两种方法
Sep 06 Python
python书籍信息爬虫实例
Mar 19 Python
Django使用HttpResponse返回图片并显示的方法
May 22 Python
python 搜索大文件的实例代码
Jul 08 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
Python更换pip源方法过程解析
May 19 Python
Python3爬虫中Splash的知识总结
Jul 10 Python
python模块内置属性概念及实例
Feb 18 Python
opencv实现图像几何变换
Mar 24 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
Protoss建筑一览
2020/03/14 星际争霸
支持php4、php5的mysql数据库操作类
2008/01/10 PHP
php更新mysql后获取影响的行数发生异常解决方法
2013/03/28 PHP
PHP实现的最大正向匹配算法示例
2017/12/19 PHP
jQuery 关于伪类选择符的使用说明
2013/04/24 Javascript
Jquery实现的tab效果可以指定默认显示第几页
2013/10/16 Javascript
node.js中的console.trace方法使用说明
2014/12/09 Javascript
javascript实现动态表头及表列的展现方法
2015/07/14 Javascript
jquery实现最简单的滑动菜单效果代码
2015/09/12 Javascript
分享一些常用的jQuery动画事件和动画函数
2015/11/27 Javascript
JavaScript正则表达式校验与递归函数实际应用实例解析
2017/08/04 Javascript
Three.js中网格对象MESH的属性与方法详解
2017/09/27 Javascript
使用Vue.js和Flask来构建一个单页的App的示例
2018/03/21 Javascript
Bootstrap table表格初始化表格数据的方法
2018/07/25 Javascript
基于webpack4.X从零搭建React脚手架的方法步骤
2018/12/23 Javascript
python面向对象_详谈类的继承与方法的重载
2017/06/07 Python
python 实现求解字符串集的最长公共前缀方法
2018/07/20 Python
在Python中输入一个以空格为间隔的数组方法
2018/11/13 Python
详解python做UI界面的方法
2019/02/27 Python
Falsk 与 Django 过滤器的使用与区别详解
2019/06/04 Python
Python 单例设计模式用法实例分析
2019/09/23 Python
ipad上运行python的方法步骤
2019/10/12 Python
python opencv将表格图片按照表格框线分割和识别
2019/10/30 Python
python 多线程中join()的作用
2020/10/29 Python
基于Python的接口自动化unittest测试框架和ddt数据驱动详解
2021/01/27 Python
英国的屈臣氏:Boots博姿
2017/12/23 全球购物
应届生护士求职信
2013/11/01 职场文书
副主任竞聘演讲稿
2014/08/18 职场文书
税务干部群众路线教育实践活动自我剖析材料
2014/09/21 职场文书
电子银行业务授权委托书
2014/10/10 职场文书
2014年高数考试作弊检讨书
2014/12/14 职场文书
2014初中数学教研组工作总结
2014/12/19 职场文书
大连星海广场导游词
2015/02/10 职场文书
2015年审计人员工作总结
2015/05/26 职场文书
公司财务制度:成本管理控制制度模板
2019/11/19 职场文书
Python多线程实用方法以及共享变量资源竞争问题
2022/04/12 Python