在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python删除文件示例分享
Jan 28 Python
python根据开头和结尾字符串获取中间字符串的方法
Mar 26 Python
Python实现字符串格式化的方法小结
Feb 20 Python
numpy concatenate数组拼接方法示例介绍
May 27 Python
Python学习笔记基本数据结构之序列类型list tuple range用法分析
Jun 08 Python
Python的条件锁与事件共享详解
Sep 12 Python
python实现简单成绩录入系统
Sep 19 Python
python如何通过闭包实现计算器的功能
Feb 22 Python
Python Websocket服务端通信的使用示例
Feb 25 Python
keras中的History对象用法
Jun 19 Python
cookies应对python反爬虫知识点详解
Nov 25 Python
Python中的min及返回最小值索引的操作
May 10 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
Zend的AutoLoad机制介绍
2012/09/27 PHP
PHP中使用sleep函数实现定时任务实例分享
2014/08/21 PHP
php数组操作之键名比较与差集、交集赋值的方法
2014/11/10 PHP
php下foreach提示Warning:Invalid argument supplied for foreach()的解决方法
2014/11/11 PHP
Yii2.0高级框架数据库增删改查的一些操作
2015/11/16 PHP
round robin权重轮循算法php实现代码
2016/05/28 PHP
JavaScript获取GridView中用户点击控件的行号,列号
2009/04/14 Javascript
JQuery 写的个性导航菜单
2009/12/24 Javascript
javascript中怎么做对象的类型判断
2013/11/11 Javascript
jquery实现树形二级菜单实例代码
2013/11/20 Javascript
httpclient模拟登陆具体实现(使用js设置cookie)
2013/12/11 Javascript
加载列表时jquery获取ul中第一个li的属性
2014/11/02 Javascript
学习JavaScript设计模式之装饰者模式
2016/01/19 Javascript
Vue框架中正确引入JS库的方法介绍
2017/07/30 Javascript
nodejs之koa2请求示例(GET,POST)
2018/08/07 NodeJs
jQuery中使用validate插件校验表单功能
2019/05/24 jQuery
浅入深出Vue之组件使用
2019/07/11 Javascript
[01:00:44]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#1COL VS Alliance第三局
2016/03/04 DOTA
[01:01:24]LGD vs Fnatic 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
pyv8学习python和javascript变量进行交互
2013/12/04 Python
深入Python函数编程的一些特性
2015/04/13 Python
在Linux下调试Python代码的各种方法
2015/04/17 Python
Python实现堆排序的方法详解
2016/05/03 Python
python计算auc指标实例
2017/07/13 Python
Python中用psycopg2模块操作PostgreSQL方法
2017/11/28 Python
Python+matplotlib+numpy绘制精美的条形统计图
2018/01/02 Python
python实现flappy bird小游戏
2018/12/24 Python
Python基于机器学习方法实现的电影推荐系统实例详解
2019/06/25 Python
使用anaconda安装pytorch的实现步骤
2020/09/03 Python
学习方法演讲稿
2014/05/10 职场文书
公司员工安全协议书
2014/11/21 职场文书
护士实习自荐信
2015/03/06 职场文书
2015年小学体育工作总结
2015/05/22 职场文书
机械原理课程设计心得体会
2016/01/15 职场文书
2021-4-5课程——SQL Server查询【3】
2021/04/05 SQL Server
Axios代理配置及封装响应拦截处理方式
2022/04/07 Vue.js