Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之图的实现方法
Jul 08 Python
Python实现多线程HTTP下载器示例
Feb 11 Python
python中的计时器timeit的使用方法
Oct 20 Python
Python3实现带附件的定时发送邮件功能
Dec 22 Python
python中的常量和变量代码详解
Jul 25 Python
Python2和Python3中urllib库中urlencode的使用注意事项
Nov 26 Python
通过python爬虫赚钱的方法
Jan 29 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 Python
利用Python的folium包绘制城市道路图的实现示例
Aug 24 Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 Python
Python更改pip镜像源的方法示例
Dec 01 Python
Python编写nmap扫描工具
Jul 21 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
PHP 将图片按创建时间进行分类存储的实现代码
2010/01/05 PHP
php使用高斯算法实现图片的模糊处理功能示例
2016/11/11 PHP
利用PHP判断文件是否为图片的方法总结
2017/01/06 PHP
安装docker和docker-compose实例详解
2019/07/30 PHP
php模拟post提交请求调用接口示例解析
2020/08/07 PHP
延时重复执行函数 lLoopRun.js
2007/05/08 Javascript
Javascript 表单之间的数据传递代码
2008/12/04 Javascript
Visual Studio中的jQuery智能提示设置方法
2010/03/27 Javascript
jquery二级导航内容均分的原理及实现
2013/08/13 Javascript
js之事件冒泡和事件捕获详细介绍
2013/10/28 Javascript
jquery实现图片灯箱明暗的遮罩效果
2013/11/15 Javascript
JS基于Ajax实现的网页Loading效果代码
2015/10/27 Javascript
javascript bom是什么及bom和dom的区别
2015/11/26 Javascript
BootStrap扔进Django里的方法详解
2016/05/13 Javascript
javascript类型系统——日期Date对象全面了解
2016/07/13 Javascript
js使用Replace结合正则替换重复出现的字符串功能示例
2016/12/27 Javascript
Vue.js实现表格动态增加删除的方法(附源码下载)
2017/01/20 Javascript
AngularJS Toaster使用详解
2017/02/24 Javascript
JS判断一个数是否是水仙花数
2017/06/11 Javascript
iview table render集成switch开关的实例
2018/03/14 Javascript
Js 利用正则表达式和replace函数获取string中所有被匹配到的文本(推荐)
2018/10/28 Javascript
python使用wxpython开发简单记事本的方法
2015/05/20 Python
Python自然语言处理之词干,词形与最大匹配算法代码详解
2017/11/16 Python
深入理解python中sort()与sorted()的区别
2018/08/29 Python
浅谈Python3识别判断图片主要颜色并和颜色库进行对比的方法
2019/10/25 Python
Python垃圾回收机制三种实现方法
2020/04/27 Python
Python使用Paramiko控制liunx第三方库
2020/05/20 Python
Python常用外部指令执行代码实例
2020/11/05 Python
Python列表的深复制和浅复制示例详解
2021/02/12 Python
CSS中的字体大小设置属性总结
2016/05/24 HTML / CSS
浅谈CSS3特性查询(Feature Query: @supports)功能简介
2017/07/31 HTML / CSS
Pat McGrath Labs官网:世界上最有影响力的化妆师推出的彩妆品牌
2018/01/07 全球购物
Lookfantastic台湾:英国彩妆美发保养购物网
2018/03/26 全球购物
大学军训感想
2014/02/12 职场文书
村干部承诺书
2014/03/28 职场文书
简易离婚协议书(范本)
2014/10/25 职场文书