TensorFlow中如何确定张量的形状实例


Posted in Python onJune 23, 2020

我们可以使用tf.shape()获取某张量的形状张量。

import tensorflow as tf
x = tf.reshape(tf.range(1000), [10, 10, 10])
sess = tf.Session()
sess.run(tf.shape(x))
 
Out[1]: array([10, 10, 10])

我们可以使用tf.shape()在计算图中确定改变张量的形状。

high = tf.shape(x)[0] // 2
width = tf.shape(x)[1] * 2
x_reshape = tf.reshape(x, [high, width, -1])
sess.run(tf.shape(x_reshape))
 
Out: array([ 5, 20, 10])

我们可以使用tf.shape_n()在计算图中得到若干个张量的形状。

y = tf.reshape(tf.range(504), [7,8,9])
sess.run(tf.shape_n([x, y]))
 
Out: [array([10, 10, 10]), array([7, 8, 9])]

我们可以使用tf.size()获取张量的元素个数。

sess.run([tf.size(x), tf.size(y)])

Out: [1000, 504]

tensor.get_shape()或者tensor.shape是无法在计算图中用于确定张量的形状。

In [20]: x.get_shape()
Out[20]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])
 
In [21]: x.get_shape()[0]
Out[21]: Dimension(10)
 
In [22]: type(x.get_shape()[0])
Out[22]: tensorflow.python.framework.tensor_shape.Dimension
 
In [23]: x.get_shape()
Out[23]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])
 
In [24]: sess.run(x.get_shape())
---------------------------------------------------------------------------
TypeError     Traceback (most recent call last)
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, fetches, contraction_fn)
 299  self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 300  fetch, allow_tensor=True, allow_operation=True))
 301 except TypeError as e:
 
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
 3477 with self._lock:
-> 3478 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
 3479
 
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
 3566 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
-> 3567        types_str))
 3568
 
TypeError: Can not convert a TensorShapeV1 into a Tensor or Operation.
 
During handling of the above exception, another exception occurred:
 
TypeError     Traceback (most recent call last)
<ipython-input-24-de007c69e003> in <module>
----> 1 sess.run(x.get_shape())
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
 927 try:
 928 result = self._run(None, fetches, feed_dict, options_ptr,
--> 929    run_metadata_ptr)
 930 if run_metadata:
 931  proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
 1135 # Create a fetch handler to take care of the structure of fetches.
 1136 fetch_handler = _FetchHandler(
-> 1137  self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
 1138
 1139 # Run request and get response.
 
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, graph, fetches, feeds, feed_handles)
 469 """
 470 with graph.as_default():
--> 471 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
 472 self._fetches = []
 473 self._targets = []
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in for_fetch(fetch)
 269  if isinstance(fetch, tensor_type):
 270  fetches, contraction_fn = fetch_fn(fetch)
--> 271  return _ElementFetchMapper(fetches, contraction_fn)
 272 # Did not find anything.
 273 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __init__(self, fetches, contraction_fn)
 302  raise TypeError('Fetch argument %r has invalid type %r, '
 303    'must be a string or Tensor. (%s)' %
--> 304    (fetch, type(fetch), str(e)))
 305 except ValueError as e:
 306  raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument TensorShape([Dimension(10), Dimension(10), Dimension(10)]) has invalid type <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>, must be a string or Tensor. (Can not convert a TensorShapeV1 into a Tensor or Operation.)

我们可以使用tf.rank()来确定张量的秩。tf.rank()会返回一个代表张量秩的张量,可直接在计算图中使用。

In [25]: tf.rank(x)
Out[25]: <tf.Tensor 'Rank:0' shape=() dtype=int32>
 
In [26]: sess.run(tf.rank(x))
Out[26]: 3

补充知识:tensorflow循环改变tensor的值

使用tf.concat()实现4维tensor的循环赋值

alist=[[[[1,1,1],[2,2,2],[3,3,3]],[[4,4,4],[5,5,5],[6,6,6]]],[[[7,7,7],[8,8,8],[9,9,9]],[[10,10,10],[11,11,11],[12,12,12]]]] #2,2,3,3-n,c,h,w
kenel=(np.asarray(alist)*2).tolist()
print(kenel)
inputs=tf.constant(alist,dtype=tf.float32)
kenel=tf.constant(kenel,dtype=tf.float32)
inputs=tf.transpose(inputs,[0,2,3,1]) #n,h,w,c
kenel=tf.transpose(kenel,[0,2,3,1]) #n,h,w,c
uints=inputs.get_shape()
h=int(uints[1])
w=int(uints[2])
encoder_output=[]
for b in range(int(uints[0])):
 encoder_output_c=[]
 for c in range(int(uints[-1])):
  one_channel_in = inputs[b, :, :, c]
  one_channel_in = tf.reshape(one_channel_in, [1, h, w, 1])
  one_channel_kernel = kenel[b, :, :, c]
  one_channel_kernel = tf.reshape(one_channel_kernel, [h, w, 1, 1])
  encoder_output_cc = tf.nn.conv2d(input=one_channel_in, filter=one_channel_kernel, strides=[1, 1, 1, 1], padding="SAME")
  if c==0:
   encoder_output_c=encoder_output_cc
  else:
   encoder_output_c=tf.concat([encoder_output_c,encoder_output_cc],axis=3)

 if b==0:
  encoder_output=encoder_output_c
 else:
  encoder_output = tf.concat([encoder_output, encoder_output_c], axis=0)

with tf.Session() as sess:
 print(sess.run(tf.transpose(encoder_output,[0,3,1,2])))
 print(encoder_output.get_shape())

输出:

[[[[ 32. 48. 32.]
 [ 56. 84. 56.]
 [ 32. 48. 32.]]

 [[ 200. 300. 200.]
 [ 308. 462. 308.]
 [ 200. 300. 200.]]]


 [[[ 512. 768. 512.]
 [ 776. 1164. 776.]
 [ 512. 768. 512.]]

 [[ 968. 1452. 968.]
 [1460. 2190. 1460.]
 [ 968. 1452. 968.]]]]
(2, 3, 3, 2)

以上这篇TensorFlow中如何确定张量的形状实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现找出数组中第2大数字的方法示例
Mar 26 Python
Python 实现子类获取父类的类成员方法
Jan 11 Python
python+selenium 定位到元素,无法点击的解决方法
Jan 30 Python
Python实现插入排序和选择排序的方法
May 12 Python
python 画出使用分类器得到的决策边界
Aug 21 Python
python实发邮件实例详解
Nov 11 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 Python
OpenCV Python实现拼图小游戏
Mar 23 Python
Django ORM 查询表中某列字段值的方法
Apr 30 Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 Python
Django Auth用户认证组件实现代码
Oct 13 Python
Python如何加载模型并查看网络
Jul 15 Python
Python docutils文档编译过程方法解析
Jun 23 #Python
python3的pip路径在哪
Jun 23 #Python
Python错误的处理方法
Jun 23 #Python
python文件读取失败怎么处理
Jun 23 #Python
使用tensorflow根据输入更改tensor shape
Jun 23 #Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
You might like
PHP记录页面停留时间的方法
2016/03/30 PHP
PHP反射API示例分享
2016/10/08 PHP
Laravel 手动开关 Eloquent 修改器的操作方法
2019/12/30 PHP
PHP 枚举类型的管理与设计知识点总结
2020/02/13 PHP
JavaScript 指导方针
2007/04/05 Javascript
利用js跨页面保存变量做菜单的方法
2008/01/17 Javascript
jquery 简短几句代码实现给元素动态添加及获取提示信息
2011/09/01 Javascript
JavaScript通过元素的ID和name设置样式
2014/07/08 Javascript
jQuery实现简单的图片查看器
2020/09/11 Javascript
jQuery自动完成插件completer附源码下载
2016/01/04 Javascript
JS读取XML文件数据并以table形式显示数据的方法(兼容IE与火狐)
2016/06/02 Javascript
AngularJS使用ng-options指令实现下拉框
2016/08/23 Javascript
JS闭包与延迟求值用法示例
2016/12/22 Javascript
vue开发调试神器vue-devtools使用详解
2017/07/13 Javascript
vue2.0在table中实现全选和反选的示例代码
2017/11/04 Javascript
vue一个页面实现音乐播放器的示例
2018/02/06 Javascript
vue-cli脚手架-bulid下的配置文件
2018/03/27 Javascript
js设置鼠标悬停改变背景色实现详解
2019/06/26 Javascript
JS实现简易留言板特效
2019/12/23 Javascript
vue2.0 解决抽取公用js的问题
2020/07/31 Javascript
python快速排序代码实例
2013/11/21 Python
python双向链表实现实例代码
2013/11/21 Python
python读取LMDB中图像的方法
2018/07/02 Python
用Python读取几十万行文本数据
2018/12/24 Python
调试Django时打印SQL语句的日志代码实例
2019/09/12 Python
python中封包建立过程实例
2021/02/18 Python
KENZO官网:高田贤三在法国创立的品牌
2019/05/16 全球购物
《记金华的双龙洞》教学反思
2014/04/19 职场文书
大气污染防治方案
2014/05/19 职场文书
乡镇食品安全责任书
2014/07/28 职场文书
企业领导对照检查材料
2014/08/20 职场文书
群众路线剖析材料范文
2014/10/09 职场文书
求职简历自我评价范文
2015/03/10 职场文书
2015小学毕业班工作总结
2015/07/21 职场文书
环保主题班会教案
2015/08/13 职场文书
研究生学习计划书应该怎么写?
2019/09/10 职场文书