TensorFlow中如何确定张量的形状实例


Posted in Python onJune 23, 2020

我们可以使用tf.shape()获取某张量的形状张量。

import tensorflow as tf
x = tf.reshape(tf.range(1000), [10, 10, 10])
sess = tf.Session()
sess.run(tf.shape(x))
 
Out[1]: array([10, 10, 10])

我们可以使用tf.shape()在计算图中确定改变张量的形状。

high = tf.shape(x)[0] // 2
width = tf.shape(x)[1] * 2
x_reshape = tf.reshape(x, [high, width, -1])
sess.run(tf.shape(x_reshape))
 
Out: array([ 5, 20, 10])

我们可以使用tf.shape_n()在计算图中得到若干个张量的形状。

y = tf.reshape(tf.range(504), [7,8,9])
sess.run(tf.shape_n([x, y]))
 
Out: [array([10, 10, 10]), array([7, 8, 9])]

我们可以使用tf.size()获取张量的元素个数。

sess.run([tf.size(x), tf.size(y)])

Out: [1000, 504]

tensor.get_shape()或者tensor.shape是无法在计算图中用于确定张量的形状。

In [20]: x.get_shape()
Out[20]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])
 
In [21]: x.get_shape()[0]
Out[21]: Dimension(10)
 
In [22]: type(x.get_shape()[0])
Out[22]: tensorflow.python.framework.tensor_shape.Dimension
 
In [23]: x.get_shape()
Out[23]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])
 
In [24]: sess.run(x.get_shape())
---------------------------------------------------------------------------
TypeError     Traceback (most recent call last)
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, fetches, contraction_fn)
 299  self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 300  fetch, allow_tensor=True, allow_operation=True))
 301 except TypeError as e:
 
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
 3477 with self._lock:
-> 3478 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
 3479
 
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
 3566 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
-> 3567        types_str))
 3568
 
TypeError: Can not convert a TensorShapeV1 into a Tensor or Operation.
 
During handling of the above exception, another exception occurred:
 
TypeError     Traceback (most recent call last)
<ipython-input-24-de007c69e003> in <module>
----> 1 sess.run(x.get_shape())
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
 927 try:
 928 result = self._run(None, fetches, feed_dict, options_ptr,
--> 929    run_metadata_ptr)
 930 if run_metadata:
 931  proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
 1135 # Create a fetch handler to take care of the structure of fetches.
 1136 fetch_handler = _FetchHandler(
-> 1137  self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
 1138
 1139 # Run request and get response.
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, graph, fetches, feeds, feed_handles)
 469 """
 470 with graph.as_default():
--> 471 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
 472 self._fetches = []
 473 self._targets = []
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in for_fetch(fetch)
 269  if isinstance(fetch, tensor_type):
 270  fetches, contraction_fn = fetch_fn(fetch)
--> 271  return _ElementFetchMapper(fetches, contraction_fn)
 272 # Did not find anything.
 273 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, fetches, contraction_fn)
 302  raise TypeError('Fetch argument %r has invalid type %r, '
 303    'must be a string or Tensor. (%s)' %
--> 304    (fetch, type(fetch), str(e)))
 305 except ValueError as e:
 306  raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument TensorShape([Dimension(10), Dimension(10), Dimension(10)]) has invalid type <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>, must be a string or Tensor. (Can not convert a TensorShapeV1 into a Tensor or Operation.)

我们可以使用tf.rank()来确定张量的秩。tf.rank()会返回一个代表张量秩的张量,可直接在计算图中使用。

In [25]: tf.rank(x)
Out[25]: <tf.Tensor 'Rank:0' shape=() dtype=int32>
 
In [26]: sess.run(tf.rank(x))
Out[26]: 3

补充知识:tensorflow循环改变tensor的值

使用tf.concat()实现4维tensor的循环赋值

alist=[[[[1,1,1],[2,2,2],[3,3,3]],[[4,4,4],[5,5,5],[6,6,6]]],[[[7,7,7],[8,8,8],[9,9,9]],[[10,10,10],[11,11,11],[12,12,12]]]] #2,2,3,3-n,c,h,w
kenel=(np.asarray(alist)*2).tolist()
print(kenel)
inputs=tf.constant(alist,dtype=tf.float32)
kenel=tf.constant(kenel,dtype=tf.float32)
inputs=tf.transpose(inputs,[0,2,3,1]) #n,h,w,c
kenel=tf.transpose(kenel,[0,2,3,1]) #n,h,w,c
uints=inputs.get_shape()
h=int(uints[1])
w=int(uints[2])
encoder_output=[]
for b in range(int(uints[0])):
 encoder_output_c=[]
 for c in range(int(uints[-1])):
  one_channel_in = inputs[b, :, :, c]
  one_channel_in = tf.reshape(one_channel_in, [1, h, w, 1])
  one_channel_kernel = kenel[b, :, :, c]
  one_channel_kernel = tf.reshape(one_channel_kernel, [h, w, 1, 1])
  encoder_output_cc = tf.nn.conv2d(input=one_channel_in, filter=one_channel_kernel, strides=[1, 1, 1, 1], padding="SAME")
  if c==0:
   encoder_output_c=encoder_output_cc
  else:
   encoder_output_c=tf.concat([encoder_output_c,encoder_output_cc],axis=3)

 if b==0:
  encoder_output=encoder_output_c
 else:
  encoder_output = tf.concat([encoder_output, encoder_output_c], axis=0)

with tf.Session() as sess:
 print(sess.run(tf.transpose(encoder_output,[0,3,1,2])))
 print(encoder_output.get_shape())

输出:

[[[[ 32. 48. 32.]
 [ 56. 84. 56.]
 [ 32. 48. 32.]]

 [[ 200. 300. 200.]
 [ 308. 462. 308.]
 [ 200. 300. 200.]]]


 [[[ 512. 768. 512.]
 [ 776. 1164. 776.]
 [ 512. 768. 512.]]

 [[ 968. 1452. 968.]
 [1460. 2190. 1460.]
 [ 968. 1452. 968.]]]]
(2, 3, 3, 2)

以上这篇TensorFlow中如何确定张量的形状实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的ORM框架SQLObject入门实例
Apr 28 Python
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
MAC中PyCharm设置python3解释器
Dec 15 Python
python http基本验证方法
Dec 26 Python
python+opencv打开摄像头,保存视频、拍照功能的实现方法
Jan 08 Python
Python微医挂号网医生数据抓取
Jan 24 Python
Python控制键盘鼠标pynput的详细用法
Jan 28 Python
Python (Win)readline和tab补全的安装方法
Aug 27 Python
解析python实现Lasso回归
Sep 11 Python
pygame实现俄罗斯方块游戏(AI篇1)
Oct 29 Python
容易被忽略的Python内置类型
Sep 03 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
Python docutils文档编译过程方法解析
Jun 23 #Python
python3的pip路径在哪
Jun 23 #Python
Python错误的处理方法
Jun 23 #Python
python文件读取失败怎么处理
Jun 23 #Python
使用tensorflow根据输入更改tensor shape
Jun 23 #Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
You might like
mysql5详细安装教程
2007/01/15 PHP
PHP 递归效率分析
2009/11/24 PHP
PHP 杂谈《重构-改善既有代码的设计》之五 简化函数调用
2012/05/07 PHP
php使用curl详细解析及问题汇总
2016/08/11 PHP
phpStudy 2016 使用教程详解(支持PHP7)
2017/10/18 PHP
在Z-Blog中运行代码[html][/html](纯JS版)
2007/03/25 Javascript
jQuery学习笔记之jQuery的DOM操作
2010/12/22 Javascript
js实现带搜索功能的下拉框实时搜索实时匹配
2013/11/05 Javascript
为Javascript中的String对象添加去除左右空格的方法(示例代码)
2013/11/30 Javascript
js 处理数组重复元素示例代码
2013/12/27 Javascript
js字符串转换成数字与数字转换成字符串的实现方法
2014/01/08 Javascript
JavaScript中textRange对象使用方法小结
2015/03/24 Javascript
javascript获取重复次数最多的字符
2015/07/08 Javascript
jQuery获取剪贴板内容的方法
2016/06/16 Javascript
vuejs在解析时出现闪烁的原因及防止闪烁的方法
2016/09/19 Javascript
详解react组件通讯方式(多种)
2020/05/06 Javascript
node.js爬虫框架node-crawler初体验
2020/10/29 Javascript
Python正则抓取新闻标题和链接的方法示例
2017/04/24 Python
Python 通配符删除文件的实例
2018/04/24 Python
Windows下将Python文件打包成.EXE可执行文件的方法
2018/08/03 Python
Python函数的默认参数设计示例详解
2019/12/01 Python
h5页面背景图很长要有滚动条滑动效果的实现
2021/01/27 HTML / CSS
phpquery中文手册
2021/03/18 PHP
美国紧身牛仔裤品牌:NYDJ
2017/05/24 全球购物
英国家庭和商业健身器材购物网站:Fitness Options
2018/07/05 全球购物
澳大利亚床上用品、浴巾和家居用品购物网站:Bambury
2020/04/16 全球购物
静态成员和非静态成员的区别
2012/05/12 面试题
机关出纳岗位职责
2014/04/03 职场文书
公司年终奖分配方案
2014/06/16 职场文书
师德师风剖析材料
2014/09/30 职场文书
职代会闭幕词
2015/01/28 职场文书
公务员年度个人总结
2015/02/12 职场文书
护士工作心得体会
2016/01/25 职场文书
SQL Server表分区删除详情
2021/10/16 SQL Server
前端JavaScript大管家 package.json
2021/11/02 Javascript
mysql insert 存在即不插入语法说明
2022/03/25 MySQL