在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现简单socket程序在两台电脑之间传输消息的方法
Mar 13 Python
Python实现给qq邮箱发送邮件的方法
May 28 Python
python Django模板的使用方法
Jan 14 Python
详解python如何调用C/C++底层库与互相传值
Aug 10 Python
Python+MongoDB自增键值的简单实现
Nov 04 Python
Python自动生产表情包
Mar 17 Python
Python字典实现简单的三级菜单(实例讲解)
Jul 31 Python
Python数字图像处理之霍夫线变换实现详解
Jan 12 Python
安装python3的时候就是输入python3死活没有反应的解决方法
Jan 24 Python
使用Python设计一个代码统计工具
Apr 04 Python
Python2和3字符编码的区别知识点整理
Aug 08 Python
python操作docx写入内容,并控制文本的字体颜色
Feb 13 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
PHP中用hash实现的数组
2011/07/17 PHP
PHP的Json中文处理解决方案
2016/09/29 PHP
PHP fclose函数用法总结
2019/02/15 PHP
解决laravel id非自增 模型取回为0 的问题
2019/10/11 PHP
JavaScript 语言的递归编程
2010/05/18 Javascript
JavaScript单元测试ABC
2012/04/12 Javascript
用JavaScript获取DOM元素位置和尺寸大小的方法
2013/04/12 Javascript
Node.js和MongoDB实现简单日志分析系统
2015/04/25 Javascript
用js实现博客打赏功能
2016/10/24 Javascript
sublime text配置node.js调试(图文教程)
2017/11/23 Javascript
自定义Vue组件打包、发布到npm及使用教程
2019/05/22 Javascript
vue实现文字加密功能
2019/09/27 Javascript
JS继承定义与使用方法简单示例
2020/02/19 Javascript
vue实现列表滚动的过渡动画
2020/06/29 Javascript
原生JS实现相邻月份日历
2020/10/13 Javascript
微信小程序弹窗禁止页面滚动的实现代码
2020/12/30 Javascript
nestjs返回给前端数据格式的封装实现
2021/02/22 Javascript
[04:07]显微镜下的DOTA2第八期——英雄复活动作
2014/06/24 DOTA
[48:51]完美世界DOTA2联赛PWL S2 Magma vs InkIce 第一场 11.28
2020/12/02 DOTA
python判断端口是否打开的实现代码
2013/02/10 Python
Python语言技巧之三元运算符使用介绍
2013/03/04 Python
Python __setattr__、 __getattr__、 __delattr__、__call__用法示例
2015/03/06 Python
python实现贪吃蛇游戏
2020/03/21 Python
基于python3的socket聊天编程
2020/02/17 Python
Python实现寻找回文数字过程解析
2020/06/09 Python
Shopty西班牙:缝纫机在线销售
2018/01/26 全球购物
法国发饰品牌:Alexandre De Paris
2018/12/04 全球购物
俄罗斯极限运动网上商店:Board Shop №1
2020/12/18 全球购物
教师旷工检讨书
2014/01/18 职场文书
幼儿园元旦亲子活动方案
2014/02/17 职场文书
新农村建设标语
2014/06/24 职场文书
党员组织生活会发言材料
2014/10/17 职场文书
道德与公民自我评价
2015/03/09 职场文书
新学期感想
2015/08/10 职场文书
教师学习心得体会范文
2016/01/21 职场文书
Mysql如何查看是否使用到索引
2022/12/24 MySQL