PyTorch中Tensor的维度变换实现


Posted in Python onAugust 18, 2019

对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。

维度查看:torch.Tensor.size()

查看当前 tensor 的维度

举个例子:

>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])

张量变形:torch.Tensor.view(*args) → Tensor

返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。

举个例子:

>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774,
     0.3455],
    [-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238,
     -0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
    [ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874],
    [-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
     [-0.6229, -1.0310, -0.8038],
     [ 0.5166, 0.9774, 0.3455]],

    [[-0.2306, 0.4217, 1.2874],
     [-0.3618, 1.7872, -0.9012],
     [ 0.8073, -1.1238, -0.3405]]])

可以看到 x 和 y 、z 中数据的数量和每个数据的大小都是相等的,只是尺寸或维度数量发生了改变。

压缩 / 解压张量:torch.squeeze()、torch.unsqueeze()

  • torch.squeeze(input, dim=None, out=None)

将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

举个例子:

>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986, 0.4352]],

    [[ 0.0971, 0.2296]],

    [[ 0.8339, -0.5433]]])
>>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986, 0.4352],
    [ 0.0971, 0.2296],
    [ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])

可以看到如果加参数,只有维度中尺寸为 1 的位置才会消失

  • torch.unsqueeze(input, dim, out=None)

返回一个新的张量,对输入的制定位置插入维度 1

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果 dim 为负,则将会被转化 dim+input.dim()+1

接着用上面的数据举个例子:

>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986, 0.4352]],

     [[ 0.0971, 0.2296]],

     [[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],
     [ 0.4352]]],


    [[[ 0.0971],
     [ 0.2296]]],


    [[[ 0.8339],
     [-0.5433]]]])

可以看到在指定的位置,增加了一个维度。

扩大张量:torch.Tensor.expand(*sizes) → Tensor

返回 tensor 的一个新视图,单个维度扩大为更大的尺寸。 tensor 也可以扩大为更高维,新增加的维度将附在前面。 扩大 tensor 不需要分配新内存,只是仅仅新建一个 tensor 的视图,其中通过将 stride 设为 0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。

举个例子:

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],
    [2.],
    [3.]])

原数据是 3 行 1 列,扩大后变为 3 行 4 列,方法中填 -1 的效果与 1 一样,只有尺寸为 1 才可以扩大,如果不为 1 就无法改变,而且尺寸不为 1 的维度必须要和原来一样填写进去。

重复张量:torch.Tensor.repeat(*sizes)

沿着指定的维度重复 tensor。 不同于 expand(),本函数复制的是 tensor 中的数据。

举个例子:

>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])

原数据为 1 行 3 列,按行方向扩大为原来的 4 倍,列方向扩大为原来的 2 倍,变为了 4 行 6 列。

变化时可以看成是把原数据作成一个整体,再按指定的维度和尺寸重复,变成一个 4 行 2 列的矩阵,其中的每一个单位都是相同的,再把原数据放到每个单位中。

矩阵转置:torch.t(input, out=None) → Tensor

输入一个矩阵(2维张量),并转置0, 1维。 可以被视为函数 transpose(input, 0, 1) 的简写函数。

举个例子:

>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224, 0.9776],
    [ 0.2489, -0.2986, -0.7816, -0.0823, 1.1811],
    [-1.1124, 0.2160, -0.8446, 0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])
>>> torch.t(x) # 另一种用法
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])

必须要是 2 维的张量,也就是矩阵,才可以使用。

维度置换:torch.transpose()、torch.Tensor.permute()

  • torch.transpose(input, dim0, dim1, out=None) → Tensor

返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,所以改变其中一个会导致另外一个也被修改。

举个例子:

>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363, 0.5534],
     [-0.2050, 3.1847, -1.6729],
     [-0.2591, -0.0860, 0.4660],
     [-1.2189, -1.1206, 0.0637]],

    [[ 1.4791, -0.7569, 2.5017],
     [ 0.0098, -1.0217, 0.8142],
     [-0.2414, -0.1790, 2.3506],
     [-0.6860, -0.2363, 1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [-0.7363, 3.1847, -0.0860, -1.1206],
     [ 0.5534, -1.6729, 0.4660, 0.0637]],

    [[ 1.4791, 0.0098, -0.2414, -0.6860],
     [-0.7569, -1.0217, -0.1790, -0.2363],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363, 0.5534],
     [ 1.4791, -0.7569, 2.5017]],

    [[-0.2050, 3.1847, -1.6729],
     [ 0.0098, -1.0217, 0.8142]],

    [[-0.2591, -0.0860, 0.4660],
     [-0.2414, -0.1790, 2.3506]],

    [[-1.2189, -1.1206, 0.0637],
     [-0.6860, -0.2363, 1.0481]]])

可以对多维度的张量进行转置

  • torch.Tensor.permute(dims)

将 tensor 的维度换位

接着用上面的数据举个例子:

>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [ 1.4791, 0.0098, -0.2414, -0.6860]],

    [[-0.7363, 3.1847, -0.0860, -1.1206],
     [-0.7569, -1.0217, -0.1790, -0.2363]],

    [[ 0.5534, -1.6729, 0.4660, 0.0637],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])

直接在方法中填入各个维度的索引,张量就会交换指定维度的尺寸,不限于两两交换。

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的测试模块unittest和doctest的使用教程
Apr 14 Python
Python 中的with关键字使用详解
Sep 11 Python
python实现发送邮件及附件功能
Mar 02 Python
python实现在pandas.DataFrame添加一行
Apr 04 Python
python随机取list中的元素方法
Apr 08 Python
python matplotlib 在指定的两个点之间连线方法
May 25 Python
在Python中分别打印列表中的每一个元素方法
Nov 07 Python
Windows平台Python编程必会模块之pywin32介绍
Oct 01 Python
python安装后的目录在哪里
Jun 21 Python
Python自定义sorted排序实现方法详解
Sep 18 Python
python 获取剪切板内容的两种方法
Nov 28 Python
python os.listdir()乱码解决方案
Jan 31 Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
You might like
使用php统计字符串中中英文字符的个数
2013/06/23 PHP
php接口和抽象类使用示例详解
2014/03/02 PHP
Laravel如何使用Redis共享Session
2018/02/23 PHP
nullJavascript中创建对象的五种方法实例
2013/05/07 Javascript
document节点对象的获取方式示例介绍
2013/12/24 Javascript
JavaScript 消息框效果【实现代码】
2016/04/27 Javascript
jQuery的extend方法【三种】
2016/12/14 Javascript
使用JavaScript开发跨平台的桌面应用详解
2017/07/27 Javascript
ES6顶层对象、global对象实例分析
2019/06/14 Javascript
element-ui 远程搜索组件el-select在项目中组件化的实现代码
2019/12/04 Javascript
简单了解JS打开url的方法
2020/02/21 Javascript
vue+elementUI 实现内容区域高度自适应的示例
2020/09/26 Javascript
JavaScript 声明私有变量的两种方式
2021/02/05 Javascript
python使用socket远程连接错误处理方法
2015/04/29 Python
Python同时向控制台和文件输出日志logging的方法
2015/05/26 Python
Python使用email模块对邮件进行编码和解码的实例教程
2016/07/01 Python
基于Python代码编辑器的选用(详解)
2017/09/13 Python
安装好Pycharm后如何配置Python解释器简易教程
2019/06/28 Python
Python定义函数时参数有默认值问题解决
2019/12/19 Python
如何在windows下安装Pycham2020软件(方法步骤详解)
2020/05/03 Python
学python需要去培训机构吗
2020/07/01 Python
python如何快速拼接字符串
2020/10/28 Python
新加坡领先的时尚生活方式零售品牌:CHARLES & KEITH
2018/01/16 全球购物
Mankind美国/加拿大:英国领先的男士美容护发用品公司
2018/12/05 全球购物
亲戚结婚的请假条
2014/02/11 职场文书
2014年公司迎新年活动方案
2014/02/24 职场文书
品牌转让协议书
2014/08/20 职场文书
作风大整顿心得体会
2014/09/10 职场文书
毕业生实习证明
2014/09/19 职场文书
党支部四风整改方案
2014/10/25 职场文书
教代会闭幕词
2015/01/28 职场文书
2015年质检工作总结
2015/05/04 职场文书
转学证明范本
2015/06/19 职场文书
篮球赛闭幕式主持词
2015/07/03 职场文书
2016入党积极分子心得体会
2016/01/06 职场文书
MySQL数据管理操作示例讲解
2022/12/24 MySQL