Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用新浪微博api上传图片到微博示例
Jan 10 Python
解决PyCharm中光标变粗的问题
Aug 05 Python
python实现随机森林random forest的原理及方法
Dec 21 Python
python抓取搜狗微信公众号文章
Apr 01 Python
Python秒算24点实现及原理详解
Jul 29 Python
Python如何使用k-means方法将列表中相似的句子归类
Aug 08 Python
Python类和实例的属性机制原理详解
Mar 21 Python
Python3操作读写CSV文件使用包过程解析
Apr 10 Python
用OpenCV进行年龄和性别检测的实现示例
Jan 29 Python
python实战之一步一步教你绘制小猪佩奇
Apr 22 Python
Python爬虫之爬取二手房信息
Apr 27 Python
Python 居然可以在 Excel 中画画你知道吗
Feb 15 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
《雄兵连》《烈阳天道》真的来了
2020/07/13 国漫
对盗链说再见...
2006/10/09 PHP
PHP __autoload函数(自动载入类文件)的使用方法
2012/02/04 PHP
总结一些js自定义的函数
2006/08/05 Javascript
在textarea中显示html页面的javascript代码
2007/04/20 Javascript
js跨域和ajax 跨域问题的实现思路
2009/09/05 Javascript
jquery.validate使用攻略 第一部
2010/07/01 Javascript
javascript 弹出窗口中是否显示地址栏的实现代码
2011/04/14 Javascript
js中的布尔运算符使用介绍
2013/11/20 Javascript
Javascript实现获取窗口的大小和位置代码分享
2014/12/04 Javascript
使用JS画图之点、线、面
2015/01/12 Javascript
Javascript实现代码折叠功能
2016/08/25 Javascript
js基本算法:冒泡排序,二分查找的简单实例
2016/10/08 Javascript
Bootstrap弹出框modal上层的输入框不能获得焦点问题的解决方法
2016/12/13 Javascript
js限制input只能输入有效的数字(第一个不能是小数点)
2018/09/28 Javascript
node获取客户端ip功能简单示例
2019/08/24 Javascript
vue实现div单选多选功能
2020/07/16 Javascript
ant-design表单处理和常用方法及自定义验证操作
2020/10/27 Javascript
[53:10]Secret vs Pain 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
Python自动化测试工具Splinter简介和使用实例
2014/05/13 Python
python中zip()方法应用实例分析
2016/04/16 Python
详解Python中使用base64模块来处理base64编码的方法
2016/07/01 Python
python机器学习实战之最近邻kNN分类器
2017/12/20 Python
Python 爬虫之Beautiful Soup模块使用指南
2018/07/05 Python
python conda操作方法
2019/09/11 Python
ipad上运行python的方法步骤
2019/10/12 Python
python 插入日期数据到Oracle实例
2020/03/02 Python
HTML5 Canvas 实现圆形进度条并显示数字百分比效果示例
2017/08/18 HTML / CSS
园林设计师自荐信
2013/11/18 职场文书
给同事的道歉信
2014/01/11 职场文书
篝火晚会主持词
2014/03/25 职场文书
计划生育证明格式及范本
2014/10/09 职场文书
小学生禁毒教育心得体会
2016/01/15 职场文书
教师节作文之小学四年级
2019/09/03 职场文书
Oracle中update和select 关联操作
2022/01/18 Oracle
MySQL 原理与优化之Update 优化
2022/08/14 MySQL