在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用cmd复制文件代码分享
Dec 27 Python
python使用pil生成缩略图的方法
Mar 26 Python
Python中用startswith()函数判断字符串开头的教程
Apr 07 Python
python写入并获取剪切板内容的实例
May 31 Python
pandas.DataFrame选取/排除特定行的方法
Jul 03 Python
对Python 3.5拼接列表的新语法详解
Nov 08 Python
漂亮的Django Markdown富文本app插件的实现
Jan 02 Python
python tornado修改log输出方式
Nov 18 Python
python几种常用功能实现代码实例
Dec 25 Python
Django之choices选项和富文本编辑器的使用详解
Apr 01 Python
Python 爬取淘宝商品信息栏目的实现
Feb 06 Python
方法汇总:Python 安装第三方库常用
Apr 26 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
【星际争霸1】人族1v7家ZBath
2020/03/04 星际争霸
mysql5的sql文件导入到mysql4的方法
2008/10/19 PHP
php使用指定编码导出mysql数据到csv文件的方法
2015/03/31 PHP
JS event使用方法详解
2008/04/28 Javascript
简单的前端js+ajax 购物车框架(入门篇)
2011/10/29 Javascript
JS循环遍历JSON数据的方法
2014/07/08 Javascript
基于javascript、ajax、memcache和PHP实现的简易在线聊天室
2015/02/03 Javascript
javascript获得当前的信息的一些常用命令
2015/02/25 Javascript
JavaScript实现随机替换图片的方法
2015/04/16 Javascript
JS+CSS实现自适应选项卡宽度的圆角滑动门效果
2015/09/15 Javascript
JavaScript中的ParseInt(&quot;08&quot;)和“09”返回0的原因分析及解决办法
2016/05/19 Javascript
使用伪命名空间封装保护独自创建的对象方法
2016/08/04 Javascript
Bootstrap 网站实例之单页营销网站
2016/10/20 Javascript
vue.js数据绑定的方法(单向、双向和一次性绑定)
2017/07/13 Javascript
JavaScript对JSON数据进行排序和搜索
2017/07/24 Javascript
详解关于微信setData回调函数中的坑
2019/02/18 Javascript
微信小程序使用蓝牙小插件
2019/09/23 Javascript
微信小程序日历插件代码实例
2019/12/04 Javascript
ant-design-vue 实现表格内部字段验证功能
2019/12/16 Javascript
微信小程序订阅消息(java后端实现)开发
2020/06/01 Javascript
使用Python的Zato发送AMQP消息的教程
2015/04/16 Python
Python之自动获取公网IP的实例讲解
2017/10/01 Python
Python之列表实现栈的工作功能
2019/01/28 Python
Pycharm 文件更改目录后,执行路径未更新的解决方法
2019/07/19 Python
Python sys模块常用方法解析
2020/02/20 Python
Python多进程编程multiprocessing代码实例
2020/03/12 Python
Pycharm配置PyQt5环境的教程
2020/04/02 Python
在python下实现word2vec词向量训练与加载实例
2020/06/09 Python
python Matplotlib数据可视化(2):详解三大容器对象与常用设置
2020/09/30 Python
nohup的用法
2014/08/10 面试题
培训主管的岗位职责
2013/11/23 职场文书
名企HR怎样看待求职信
2014/02/23 职场文书
空中乘务员岗位职责
2014/03/08 职场文书
vue使用Google Recaptcha验证的实现示例
2021/08/23 Vue.js
十大冰系宝可梦排名,颜值最高的阿罗拉九尾,第三使用率第一
2022/03/18 日漫
MySQL 自动填充 create_time 和 update_time
2022/05/20 MySQL