Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现栈的方法
May 26 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
Ubuntu 下 vim 搭建python 环境 配置
Jun 12 Python
django的登录注册系统的示例代码
May 14 Python
Python使用修饰器进行异常日志记录操作示例
Mar 19 Python
python+selenium实现简历自动刷新的示例代码
May 20 Python
在PYQT5中QscrollArea(滚动条)的使用方法
Jun 14 Python
树莓派与PC端在局域网内运用python实现即时通讯
Jun 22 Python
Django ORM 查询管理器源码解析
Aug 05 Python
如何提高python 中for循环的效率
Apr 15 Python
Python中requests做接口测试的方法
May 30 Python
Python制作春联的示例代码
Jan 22 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
我的论坛源代码(十)
2006/10/09 PHP
php下连接ftp实现文件的上传、下载、删除文件实例代码
2010/06/03 PHP
php的declare控制符和ticks教程(附示例)
2014/03/21 PHP
CodeIgniter配置之SESSION用法实例分析
2016/01/19 PHP
laravel 修改.htaccess文件 重定向public的解决方法
2019/10/12 PHP
关于可运行代码无法正常执行的使用说明
2010/05/13 Javascript
基于Jquery的淡入淡出的特效基础练习
2010/12/13 Javascript
定时器(setTimeout/setInterval)调用带参函数失效解决方法
2013/03/26 Javascript
js Array对象的扩展函数代码
2013/04/24 Javascript
js中的this关键字详解
2013/09/25 Javascript
javascript简单实现图片预加载
2014/12/03 Javascript
JS模式之单例模式基本用法
2015/06/30 Javascript
js弹出对话框方式小结
2015/11/17 Javascript
详解为Angular.js内置$http服务添加拦截器的方法
2016/12/20 Javascript
Html5 js实现手风琴效果
2020/04/17 Javascript
vue2.0父子组件及非父子组件之间的通信方法
2017/01/21 Javascript
vue时间格式化实例代码
2017/06/13 Javascript
vue实现移动端图片裁剪上传功能
2020/08/18 Javascript
Node.js 使用递归实现遍历文件夹中所有文件
2017/09/18 Javascript
Vue DevTools调试工具的使用
2017/12/05 Javascript
脚手架vue-cli工程webpack的作用和特点
2018/09/29 Javascript
vue悬浮可拖拽悬浮按钮的实例代码
2019/08/20 Javascript
微信小程序实现蒙版弹出窗功能
2019/09/17 Javascript
Vue 实现简易多行滚动"弹幕"效果
2020/01/02 Javascript
[39:18]完美世界DOTA2联赛PWL S3 Forest vs LBZS 第二场 12.17
2020/12/19 DOTA
Windows下安装python2和python3多版本教程
2017/03/30 Python
python使用锁访问共享变量实例解析
2018/02/08 Python
html5跳转小程序wx-open-launch-weapp踩坑
2020/12/02 HTML / CSS
Prototype是怎么扩展DOM的
2014/10/01 面试题
大学毕业生工作的自我评价
2013/10/01 职场文书
我爱幼儿园演讲稿
2014/09/11 职场文书
情侣之间的道歉短信
2015/05/12 职场文书
青春雷锋观后感
2015/06/10 职场文书
Linux7.6二进制安装Mysql8.0.27详细操作步骤
2021/11/27 MySQL
解析MySQL索引的作用
2022/03/03 MySQL
Java实现HTML转为Word的示例代码
2022/06/28 Java/Android