在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的异常处理学习笔记
Jan 28 Python
在Python下利用OpenCV来旋转图像的教程
Apr 16 Python
Python基于ThreadingTCPServer创建多线程代理的方法示例
Jan 11 Python
python的scikit-learn将特征转成one-hot特征的方法
Jul 10 Python
python一键去抖音视频水印工具
Sep 14 Python
使用Python机器学习降低静态日志噪声
Sep 29 Python
Python八皇后问题解答过程详解
Jul 29 Python
python 实现目录复制的三种小结
Dec 04 Python
Python3自定义json逐层解析器代码
May 11 Python
python实现一次性封装多条sql语句(begin end)
Jun 06 Python
使用python对excel表格处理的一些小功能
Jan 25 Python
Python提取PDF指定内容并生成新文件
Jun 09 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
openflashchart 2.0 简单案例php版
2012/05/21 PHP
Zend Framework入门教程之Zend_Registry组件用法详解
2016/12/09 PHP
PHP实现的文件上传类与用法详解
2017/07/05 PHP
PHP与Perl之间知识点区别整理
2019/03/19 PHP
javascript之可拖动的iframe效果代码
2008/08/01 Javascript
网页自动跳转代码收集
2009/09/27 Javascript
JS是按值传递还是按引用传递
2015/01/30 Javascript
使用AngularJS实现表单向导的方法
2015/06/19 Javascript
使用JQuery在线制作ppt并在线演示源码特效
2015/09/08 Javascript
详解AngularJS中$http缓存以及处理多个$http请求的方法
2016/02/06 Javascript
Bootstrap Metronic完全响应式管理模板学习笔记
2016/07/08 Javascript
jQuery EasyUI API 中文帮助文档和扩展实例
2016/08/01 Javascript
AngularJs基于角色的前端访问控制的实现
2016/11/07 Javascript
JS实现全屏的四种写法
2016/12/30 Javascript
基于vue的短信验证码倒计时demo
2017/09/13 Javascript
jQuery使用bind函数实现绑定多个事件的方法
2017/10/11 jQuery
bootstrap 弹出框modal添加垂直方向滚轴效果
2018/07/09 Javascript
微信小程序学习笔记之表单提交与PHP后台数据交互处理图文详解
2019/03/28 Javascript
不可错过的十本Python好书
2017/07/06 Python
在CentOS6上安装Python2.7的解决方法
2018/01/09 Python
python中(str,list,tuple)基础知识汇总
2018/02/20 Python
解决DataFrame排序sort的问题
2018/06/07 Python
Python利用递归实现文件的复制方法
2018/10/27 Python
Python绘制频率分布直方图的示例
2019/07/08 Python
Eclipse配置python默认头过程图解
2020/04/26 Python
Python配置pip国内镜像源的实现
2020/08/20 Python
澳大利亚波西米亚风情网上商店:Czarina
2019/03/18 全球购物
如何在存储过程中使用Loop
2016/01/05 面试题
酒店副总经理岗位职责范本
2014/02/04 职场文书
文化建设工作方案
2014/05/12 职场文书
副总经理任命书
2014/06/05 职场文书
2014年教育培训工作总结
2014/12/08 职场文书
Python 批量下载阴阳师网站壁纸
2021/05/19 Python
浅谈Python中的函数(def)及参数传递操作
2021/05/25 Python
Python捕获、播放和保存摄像头视频并提高视频清晰度和对比度
2022/04/14 Python
Tomcat弱口令复现及利用
2022/05/06 Servers