Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python3中的gettext模块翻译Python源码以支持多语言
Mar 31 Python
python自动裁剪图像代码分享
Nov 25 Python
基于python的多进程共享变量正确打开方式
Apr 28 Python
Python 获取div标签中的文字实例
Dec 20 Python
python hough变换检测直线的实现方法
Jul 12 Python
django框架中ajax的使用及避开CSRF 验证的方式详解
Dec 11 Python
python将图片转base64,实现前端显示
Jan 09 Python
pytorch数据预处理错误的解决
Feb 20 Python
Python标准库shutil模块使用方法解析
Mar 10 Python
Python爬虫爬取电影票房数据及图表展示操作示例
Mar 27 Python
在keras里面实现计算f1-score的代码
Jun 15 Python
Django给表单添加honeypot验证增加安全性
May 06 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
PHP中利用substr_replace将指定两位置之间的字符替换为*号
2011/01/27 PHP
In Javascript Class, how to call the prototype method.(three method)
2007/01/09 Javascript
javascript 强制刷新页面的实现代码
2009/12/13 Javascript
javascript 同时在IE和FireFox获取KeyCode的代码
2010/02/07 Javascript
深入理解JavaScript定时机制
2010/10/29 Javascript
分享一个用Mootools写的鼠标滑过进度条改变进度值的实现代码
2011/12/12 Javascript
jQuery学习笔记之控制页面实现代码
2012/02/27 Javascript
javascript类型转换示例
2014/04/29 Javascript
绑定回车enter事件代码
2014/05/18 Javascript
Jquery Post处理后不进入回调的原因及解决方法
2014/07/15 Javascript
jQuery实现Meizu魅族官方网站的导航菜单效果
2015/09/14 Javascript
Jquery Easyui验证组件ValidateBox使用详解(20)
2016/12/18 Javascript
jQuery解析返回的xml和json方法详解
2017/01/05 Javascript
利用Javascript实现简单的转盘抽奖
2017/02/13 Javascript
bootstrap suggest搜索建议插件使用详解
2017/03/25 Javascript
JavaScript实现简单的星星评分效果
2017/05/18 Javascript
JavaScript Drum Kit 指南(纯 JS 模拟敲鼓效果)
2017/07/23 Javascript
mui back 返回刷新页面的实例
2017/12/06 Javascript
基于vue cli重构多页面脚手架过程详解
2018/01/23 Javascript
微信小程序地图绘制线段并且测量(实例代码)
2020/01/02 Javascript
[58:59]完美世界DOTA2联赛PWL S3 access vs CPG 第一场 12.13
2020/12/16 DOTA
使用Python脚本实现批量网站存活检测遇到问题及解决方法
2016/10/11 Python
浅析Python函数式编程
2018/10/06 Python
Pycharm无法显示动态图片的解决方法
2018/10/28 Python
详解Python3 对象组合zip()和回退方式*zip
2019/05/15 Python
Django中间件拦截未登录url实例详解
2019/09/03 Python
Python使用Tkinter实现转盘抽奖器的步骤详解
2020/01/06 Python
python如何快速拼接字符串
2020/10/28 Python
python help函数实例用法
2020/12/06 Python
html5中如何将图片的绝对路径转换成文件对象
2018/01/11 HTML / CSS
蔻驰美国官网:COACH美国
2016/08/18 全球购物
SmartBuyGlasses比利时:购买品牌太阳镜和眼镜
2019/08/09 全球购物
刚毕业大学生自荐信范文
2014/02/20 职场文书
财务主管岗位职责
2014/02/28 职场文书
2015中学学校工作总结
2015/07/20 职场文书
幼儿园心得体会范文
2016/01/21 职场文书