在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现目录树生成示例
Mar 28 Python
深入解析Python中函数的参数与作用域
Mar 20 Python
Python查找第n个子串的技巧分享
Jun 27 Python
用Python shell简化开发
Aug 08 Python
python3.x实现base64加密和解密
Mar 28 Python
python 将字符串中的数字相加求和的实现
Jul 18 Python
Django文件存储 自己定制存储系统解析
Aug 02 Python
Django Docker容器化部署之Django-Docker本地部署
Oct 09 Python
python读取dicom图像示例(SimpleITK和dicom包实现)
Jan 16 Python
浅析python中的del用法
Sep 02 Python
Python使用xpath实现图片爬取
Sep 16 Python
Python机器学习之底层实现KNN
Jun 20 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
德劲1107的电路分析与打磨
2021/03/02 无线电
PHP 存储文本换行实现方法
2010/01/05 PHP
编写PHP脚本清除WordPress头部冗余代码的方法讲解
2016/03/01 PHP
JQuery上传插件Uploadify使用详解及错误处理
2010/04/27 Javascript
一款Jquery 分页插件的改造方法(服务器端分页)
2011/07/11 Javascript
js计算精度问题小结
2013/04/22 Javascript
jquery 跳到顶部和底部动画2句代码简单实现
2013/07/18 Javascript
图标线性回归斜着移动到指定的位置
2013/08/16 Javascript
js call方法详细介绍(js 的继承)
2013/11/18 Javascript
FireBug 调试JS入门教程 如何调试JS
2013/12/23 Javascript
jquery 无限级下拉菜单的简单实现代码
2014/02/21 Javascript
jquery scroll()区分横向纵向滚动条的方法
2014/04/04 Javascript
javascript实现youku的视频代码自适应宽度
2015/05/25 Javascript
Node.js数据库操作之连接MySQL数据库(一)
2017/03/04 Javascript
js清除浏览器缓存的几种方法
2017/03/15 Javascript
微信小程序之前台循环数据绑定
2017/08/18 Javascript
vue父子组件的嵌套的示例代码
2017/09/08 Javascript
使用jQuery mobile NuGet让你的网站在移动设备上同样精彩
2019/06/18 jQuery
如何使用webpack打包一个库library的方法步骤
2019/12/18 Javascript
vue-drag-chart 拖动/缩放图表组件的实例代码
2020/04/10 Javascript
js实现九宫格布局效果
2020/05/28 Javascript
JS forEach跳出循环2种实现方法
2020/06/24 Javascript
浅谈python和C语言混编的几种方式(推荐)
2017/09/27 Python
Python实现发送与接收邮件的方法详解
2018/03/28 Python
django中related_name的用法说明
2020/05/20 Python
Silk’n激光脱毛器官网:silkn.com
2016/10/06 全球购物
党员培训思想汇报
2014/01/07 职场文书
农民工工资支付承诺函
2014/03/31 职场文书
兽医医药专业求职信
2014/07/27 职场文书
公司员工活动策划方案
2014/08/20 职场文书
2015年妇女工作总结
2015/05/14 职场文书
2015年教研室工作总结范文
2015/05/23 职场文书
运动会报道稿大全
2015/07/23 职场文书
pytest进阶教程之fixture函数详解
2021/03/29 Python
彻底理解golang中什么是nil
2021/04/29 Golang
详解CSS故障艺术
2021/05/25 HTML / CSS