Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中encode()方法的使用简介
May 18 Python
在Django的视图中使用数据库查询的方法
Jul 16 Python
Python使用Paramiko模块编写脚本进行远程服务器操作
May 05 Python
Python做文本按行去重的实现方法
Oct 19 Python
Python批量合并有合并单元格的Excel文件详解
Apr 05 Python
python矩阵转换为一维数组的实例
Jun 05 Python
利用 PyCharm 实现本地代码和远端的实时同步功能
Mar 23 Python
Python requests HTTP验证登录实现流程
Nov 05 Python
python 下载m3u8视频的示例代码
Nov 11 Python
python 三种方法提取pdf中的图片
Feb 07 Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
Jun 05 Python
Python访问Redis的详细操作
Jun 26 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
php桌面中心(三) 修改数据库
2007/03/11 PHP
php面向对象全攻略 (九)访问类型
2009/09/30 PHP
ThinkPHP设置禁止百度等搜索引擎转码(简单实用)
2016/02/15 PHP
php简单构造json多维数组的方法示例
2017/06/08 PHP
php empty 函数判断结果为空但实际值却为非空的原因解析
2018/05/28 PHP
PHP获取本周所有日期或者最近七天所有日期的方法
2018/06/20 PHP
禁止js文件缓存的代码
2010/04/09 Javascript
javascript整除实现代码
2010/11/23 Javascript
Javascript实现页面跳转的几种方式分享
2013/10/26 Javascript
css3元素简单的闪烁效果实现(html5 jquery)
2013/12/28 Javascript
手机开发必备技巧:javascript及CSS功能代码分享
2015/05/25 Javascript
jquery+ajax实现直接提交表单实例分析
2016/06/17 Javascript
提高JavaScript执行效率的23个实用技巧
2017/03/01 Javascript
源码分析Vue.js的监听实现教程
2017/04/23 Javascript
解决Webpack 热部署检测不到文件变化的问题
2018/02/22 Javascript
angular学习之动态创建表单的方法
2018/12/07 Javascript
antd Upload 文件上传的示例代码
2018/12/14 Javascript
Openlayers实现测量功能
2020/09/25 Javascript
[02:40]DOTA2殁境神蚀者 英雄基础教程
2013/11/26 DOTA
Python函数嵌套实例
2014/09/23 Python
python人民币小写转大写辅助工具
2018/06/20 Python
python批量修改ssh密码的实现
2019/08/08 Python
浅谈matplotlib.pyplot与axes的关系
2020/03/06 Python
计算机科学与技术应届生求职信
2013/11/07 职场文书
养殖行业的创业计划书
2014/01/05 职场文书
领导证婚人证婚词
2014/01/13 职场文书
父亲追悼会答谢词
2014/01/17 职场文书
水果超市创业计划书
2014/01/27 职场文书
清洁工岗位职责
2014/01/29 职场文书
报纸媒体创意广告词
2014/03/17 职场文书
董事长助理工作职责
2014/06/08 职场文书
东京审判观后感
2015/06/01 职场文书
初中政治教师教学反思
2016/02/23 职场文书
解决Go gorm踩过的坑
2021/04/30 Golang
十大公认最好看的动漫:《咒术回战》在榜,《钢之炼金术师》第一
2022/03/18 日漫
MySQL 外连接语法之 OUTER JOIN
2022/04/09 MySQL