在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
二种python发送邮件实例讲解(python发邮件附件可以使用email模块实现)
Dec 03 Python
Python使用scrapy采集数据时为每个请求随机分配user-agent的方法
Apr 08 Python
Python Matplotlib库入门指南
May 18 Python
Python编程实现的简单Web服务器示例
Jun 22 Python
python opencv之分水岭算法示例
Feb 24 Python
python高级特性和高阶函数及使用详解
Oct 17 Python
python获取交互式ssh shell的方法
Feb 14 Python
Python3.4学习笔记之 idle 清屏扩展插件用法分析
Mar 01 Python
解决在pycharm运行代码,调用CMD窗口的命令运行显示乱码问题
Aug 23 Python
python GUI库图形界面开发之PyQt5拖放控件实例详解
Feb 25 Python
Python Django form 组件动态从数据库取choices数据实例
May 19 Python
python获取整个网页源码的方法
Aug 03 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
40年前的这部特摄片恐龙特级克塞号80后的共同回忆
2020/03/08 日漫
php懒人函数 自动添加数据
2011/06/28 PHP
PHP图片处理之使用imagecopyresampled函数裁剪图片例子
2014/11/19 PHP
php实现二进制和文本相互转换的方法
2015/04/18 PHP
详解PHP序列化反序列化的方法
2015/10/27 PHP
PHP实现批量上传单个文件
2015/12/29 PHP
关于跨站脚本攻击问题
2011/12/22 Javascript
深入理解javascript中return的作用
2013/12/30 Javascript
基于Jquery制作图片文字排版预览效果附源码下载
2015/11/18 Javascript
详解Javacript和AngularJS中的Promises
2016/02/09 Javascript
关于网页中的无缝滚动的js代码
2016/06/09 Javascript
详解AngularJS如何实现跨域请求
2016/08/22 Javascript
js实现移动端微信页面禁止字体放大
2017/02/16 Javascript
jQuery插件zTree实现单独选中根节点中第一个节点示例
2017/03/08 Javascript
浅谈React + Webpack 构建打包优化
2018/01/23 Javascript
vue自定义一个v-model的实现代码
2018/06/21 Javascript
基于vue的tab-list类目切换商品列表组件的示例代码
2020/02/14 Javascript
解决echarts中横坐标值显示不全(自动隐藏)问题
2020/07/20 Javascript
video.js添加自定义组件的方法
2020/12/09 Javascript
JavaScript实现瀑布流布局的3种方式
2020/12/27 Javascript
[52:27]2018DOTA2亚洲邀请赛 3.31 小组赛B组 paiN vs Secret
2018/04/01 DOTA
跟老齐学Python之编写类之二方法
2014/10/11 Python
情人节快乐! python绘制漂亮玫瑰
2020/08/18 Python
Python爬虫实现爬取百度百科词条功能实例
2019/04/05 Python
python实现微信防撤回神器
2019/04/29 Python
python之PyQt按钮右键菜单功能的实现代码
2019/08/17 Python
python爬虫爬取幽默笑话网站
2019/10/24 Python
IE下实现类似CSS3 text-shadow文字阴影的几种方法
2011/05/11 HTML / CSS
微信浏览器左上角返回按钮拦截功能
2017/11/21 HTML / CSS
Theory美国官网:后现代都市风时装品牌
2018/05/09 全球购物
颁奖典礼主持词
2014/03/25 职场文书
承诺书的格式范文
2014/03/28 职场文书
教师学习党的群众路线教育实践活动心得体会
2014/10/31 职场文书
校运会加油稿大全
2015/07/22 职场文书
Python实现双向链表
2022/05/25 Python
Java 中的 Lambda List 转 Map 的多种方法详解
2022/07/07 Java/Android