keras slice layer 层实现方式


Posted in Python onJune 11, 2020

注意的地方: keras中每层的输入输出的tensor是张量, 比如Tensor shape是(N, H, W, C), 对于tf后台, channels_last

Define a slice layer using Lamda layer
def slice(x, h1, h2, w1, w2):
 """ Define a tensor slice function
 """
 return x[:, h1:h2, w1:w2, :]

定义完slice function之后,利用lambda layer添加到定义的网络中去

# Add slice layer
slice_1 = Lambda(slice, arguments={'h1': 0, 'h2': 6, 'w1': 0, 'w2': 6})(sliced)
# As for tensorfow backend, Lambda doesn't need output shape argument
slice_2 = Lambda(slice, arguments={'h1': 0, 'h2': 6, 'w1': 6, 'w2': 12})(sliced)

补充知识:tensorflow和keras张量切片(slice)

Notes

想将一个向量keras slice layer 层实现方式 分割成两部分:keras slice layer 层实现方式 操作大概是:

keras slice layer 层实现方式

在 TensorFlow 中,用 tf.slice 实现张量切片,Keras 中自定义 Lambda 层实现。

TensorFlow

tf.slice(input_, begin, size, name=None)

input_:tf.tensor,被操作的 tensor

begin:list,各个维度的开始下标

size:list,各个维度上要截多长

import tensorflow as tf

with tf.Session() as sess:
 a = tf.constant([1, 2, 3, 4, 5])
 b = tf.slice(a, [0], [2]) # 第一个维度从 0 开始,截 2 个
 c = tf.slice(a, [2], [3]) # 第一个维度从 2 开始,截 3 个
 print(a.eval())
 print(b.eval())
 print(c.eval())

输出

[1 2 3 4 5]
[1 2]
[3 4 5]

Keras

from keras.layers import Lambda
from keras.models import Sequential
import numpy as np

a = np.array([[1, 2, 3, 4, 5]])
model = Sequential([
 Lambda(lambda a: a[:, :2], input_shape=[5]) # 第二维截前 2 个
])

print(model.predict(a))

输出

[[1. 2.]]

以上这篇keras slice layer 层实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的几个常用排序算法实例
Jun 16 Python
Python中使用摄像头实现简单的延时摄影技术
Mar 27 Python
Python中的下划线详解
Jun 24 Python
Python排序搜索基本算法之希尔排序实例分析
Dec 09 Python
Python cookbook(数据结构与算法)找到最大或最小的N个元素实现方法示例
Feb 13 Python
django中静态文件配置static的方法
May 20 Python
详解Django中类视图使用装饰器的方式
Aug 12 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 Python
给 TensorFlow 变量进行赋值的方式
Feb 10 Python
pycharm 对代码做静态检查操作
Jun 09 Python
学生如何注册Pycharm专业版以及pycharm的安装
Sep 24 Python
基于Python-Pycharm实现的猴子摘桃小游戏(源代码)
Feb 20 Python
Spring http服务远程调用实现过程解析
Jun 11 #Python
keras Lambda自定义层实现数据的切片方式,Lambda传参数
Jun 11 #Python
python怎么提高计算速度
Jun 11 #Python
Python 实现自动登录+点击+滑动验证功能
Jun 10 #Python
Python函数参数定义及传递方式解析
Jun 10 #Python
什么是python类属性
Jun 10 #Python
基于SQLAlchemy实现操作MySQL并执行原生sql语句
Jun 10 #Python
You might like
基于php上传图片重命名的6种解决方法的详细介绍
2013/04/28 PHP
PHP树的深度编历生成迷宫及A*自动寻路算法实例分析
2015/03/10 PHP
PHP 数据结构队列(SplQueue)和优先队列(SplPriorityQueue)简单使用实例
2015/05/12 PHP
PHP 错误处理机制
2015/07/06 PHP
php 二维数组时间排序实现代码
2016/11/19 PHP
thinkPHP5框架分页样式类完整示例
2018/09/01 PHP
php实现的生成排列算法示例
2019/07/25 PHP
javascript基础第一章 JavaScript与用户端
2010/07/22 Javascript
点击弹出层外区域关闭弹出层jquery特效示例
2013/08/25 Javascript
JS不间断向上滚动效果代码
2013/12/25 Javascript
jQuery性能优化的38个建议
2014/03/04 Javascript
手写的一个兼容各种浏览器的javascript getStyle函数(获取元素的样式)
2014/06/06 Javascript
jQuery模仿阿里云购买服务器选择购买时间长度的代码
2016/04/29 Javascript
jQuery给表格添加分页效果
2017/03/02 Javascript
NodeJS如何实现同步的方法示例
2018/08/24 NodeJs
原生JS 实现的input输入时表格过滤操作示例
2019/08/03 Javascript
[05:28]刀塔密之一:团结则存
2014/07/03 DOTA
[39:11]DOTA2上海特级锦标赛C组资格赛#2 LGD VS Newbee第二局
2016/02/28 DOTA
python文字和unicode/ascll相互转换函数及简单加密解密实现代码
2019/08/12 Python
对YOLOv3模型调用时候的python接口详解
2019/08/26 Python
Pytorch基本变量类型FloatTensor与Variable用法
2020/01/08 Python
pymysql之cur.fetchall() 和cur.fetchone()用法详解
2020/05/15 Python
Pandas替换及部分替换(replace)实现流程详解
2020/10/12 Python
Python中正则表达式对单个字符,多个字符和匹配边界等使用
2021/01/27 Python
HTML5之SVG 2D入门1—SVG(可缩放矢量图形)概述
2013/01/30 HTML / CSS
canvas因为图片资源不在同一域名下而导致的跨域污染画布的解决办法
2019/01/18 HTML / CSS
Html5 canvas画图白板踩坑
2020/06/01 HTML / CSS
澳大利亚最超值的自行车之家:Reid Cycles
2019/03/24 全球购物
幼儿园家长评语
2014/02/10 职场文书
大雁塔英文导游词
2015/02/10 职场文书
公司前台接待岗位职责
2015/04/03 职场文书
2015年图书馆个人工作总结
2015/05/26 职场文书
2015暑假实习报告范文
2015/07/13 职场文书
安全生产感想
2015/08/07 职场文书
Vue.Draggable实现交换位置
2022/04/07 Vue.js
SQL Server使用T-SQL语句批处理
2022/05/20 SQL Server