在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中针对函数处理的特殊方法
Mar 06 Python
Python time模块详解(常用函数实例讲解,非常好)
Apr 24 Python
python通过floor函数舍弃小数位的方法
Mar 17 Python
Eclipse和PyDev搭建完美Python开发环境教程(Windows篇)
Nov 16 Python
python实现给微信公众号发送消息的方法
Jun 30 Python
python DataFrame获取行数、列数、索引及第几行第几列的值方法
Apr 08 Python
python+numpy+matplotalib实现梯度下降法
Aug 31 Python
Python 将Matrix、Dict保存到文件的方法
Oct 30 Python
Python获取航线信息并且制作成图的讲解
Jan 03 Python
python tkinter实现屏保程序
Jul 30 Python
Python scrapy增量爬取实例及实现过程解析
Dec 24 Python
Python基于爬虫实现全网搜索并下载音乐
Feb 14 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
PHP中GET变量的使用
2006/10/09 PHP
PHP与MySQL交互使用详解
2006/10/09 PHP
解决Laravel5.5下的toArray问题
2019/10/15 PHP
Javascript 检测、添加、移除样式(className)函数代码
2009/09/08 Javascript
五段实用的js高级技巧
2011/12/20 Javascript
一个简单的jQuery插件ajaxfileupload.js实现ajax上传文件例子
2014/06/26 Javascript
Javascript 绘制 sin 曲线过程附图
2014/08/21 Javascript
浅谈JavaScript正则表达式分组匹配
2015/04/10 Javascript
jQuery删除节点用法示例(remove方法)
2016/09/08 Javascript
JavaScript省市区三级联动菜单效果
2016/09/21 Javascript
JavaScript 控制字体大小设置的方法
2016/11/23 Javascript
原生JS实现图片轮播效果
2016/12/26 Javascript
微信小程序 五星评价功能的实现
2017/03/09 Javascript
dts文件中删除一个node或属性的操作方法
2018/08/05 Javascript
iview实现图片上传功能
2020/06/29 Javascript
在Vue中使用Echarts可视化库的完整步骤记录
2020/11/18 Vue.js
[02:51]DOTA2 2015国际邀请赛中国区预选赛第一日战报
2015/05/27 DOTA
[52:07]完美世界DOTA2联赛PWL S3 LBZS vs access 第二场 12.10
2020/12/13 DOTA
python3.3使用tkinter开发猜数字游戏示例
2014/03/14 Python
python中模块查找的原理与方法详解
2017/08/11 Python
Python访问MongoDB,并且转换成Dataframe的方法
2018/10/15 Python
捷克玩具商店:Bambule
2019/02/23 全球购物
信用社实习人员自我鉴定
2013/09/20 职场文书
研发工程师的岗位职责
2013/11/18 职场文书
心理学专业毕业生推荐信范文
2013/11/21 职场文书
家长给幼儿园的表扬信
2014/01/09 职场文书
上课打牌的检讨书
2014/02/15 职场文书
美术毕业生求职信
2014/02/25 职场文书
兴趣小组活动总结
2014/05/05 职场文书
学院党的群众路线教育实践活动整改方案
2014/10/04 职场文书
医药公司采购员岗位职责
2015/04/03 职场文书
班级元旦晚会开幕词
2016/03/04 职场文书
创业计划书之香辣虾火锅
2019/09/23 职场文书
小学三年级作文之写景
2019/11/05 职场文书
linux目录管理方法介绍
2022/06/01 Servers
CSS中calc(100%-100px)不加空格不生效
2023/05/07 HTML / CSS