PyTorch中Tensor的维度变换实现


Posted in Python onAugust 18, 2019

对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。

维度查看:torch.Tensor.size()

查看当前 tensor 的维度

举个例子:

>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])

张量变形:torch.Tensor.view(*args) → Tensor

返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。

举个例子:

>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774,
     0.3455],
    [-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238,
     -0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
    [ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874],
    [-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
     [-0.6229, -1.0310, -0.8038],
     [ 0.5166, 0.9774, 0.3455]],

    [[-0.2306, 0.4217, 1.2874],
     [-0.3618, 1.7872, -0.9012],
     [ 0.8073, -1.1238, -0.3405]]])

可以看到 x 和 y 、z 中数据的数量和每个数据的大小都是相等的,只是尺寸或维度数量发生了改变。

压缩 / 解压张量:torch.squeeze()、torch.unsqueeze()

  • torch.squeeze(input, dim=None, out=None)

将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

举个例子:

>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986, 0.4352]],

    [[ 0.0971, 0.2296]],

    [[ 0.8339, -0.5433]]])
>>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986, 0.4352],
    [ 0.0971, 0.2296],
    [ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])

可以看到如果加参数,只有维度中尺寸为 1 的位置才会消失

  • torch.unsqueeze(input, dim, out=None)

返回一个新的张量,对输入的制定位置插入维度 1

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果 dim 为负,则将会被转化 dim+input.dim()+1

接着用上面的数据举个例子:

>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986, 0.4352]],

     [[ 0.0971, 0.2296]],

     [[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],
     [ 0.4352]]],


    [[[ 0.0971],
     [ 0.2296]]],


    [[[ 0.8339],
     [-0.5433]]]])

可以看到在指定的位置,增加了一个维度。

扩大张量:torch.Tensor.expand(*sizes) → Tensor

返回 tensor 的一个新视图,单个维度扩大为更大的尺寸。 tensor 也可以扩大为更高维,新增加的维度将附在前面。 扩大 tensor 不需要分配新内存,只是仅仅新建一个 tensor 的视图,其中通过将 stride 设为 0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。

举个例子:

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],
    [2.],
    [3.]])

原数据是 3 行 1 列,扩大后变为 3 行 4 列,方法中填 -1 的效果与 1 一样,只有尺寸为 1 才可以扩大,如果不为 1 就无法改变,而且尺寸不为 1 的维度必须要和原来一样填写进去。

重复张量:torch.Tensor.repeat(*sizes)

沿着指定的维度重复 tensor。 不同于 expand(),本函数复制的是 tensor 中的数据。

举个例子:

>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])

原数据为 1 行 3 列,按行方向扩大为原来的 4 倍,列方向扩大为原来的 2 倍,变为了 4 行 6 列。

变化时可以看成是把原数据作成一个整体,再按指定的维度和尺寸重复,变成一个 4 行 2 列的矩阵,其中的每一个单位都是相同的,再把原数据放到每个单位中。

矩阵转置:torch.t(input, out=None) → Tensor

输入一个矩阵(2维张量),并转置0, 1维。 可以被视为函数 transpose(input, 0, 1) 的简写函数。

举个例子:

>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224, 0.9776],
    [ 0.2489, -0.2986, -0.7816, -0.0823, 1.1811],
    [-1.1124, 0.2160, -0.8446, 0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])
>>> torch.t(x) # 另一种用法
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])

必须要是 2 维的张量,也就是矩阵,才可以使用。

维度置换:torch.transpose()、torch.Tensor.permute()

  • torch.transpose(input, dim0, dim1, out=None) → Tensor

返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,所以改变其中一个会导致另外一个也被修改。

举个例子:

>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363, 0.5534],
     [-0.2050, 3.1847, -1.6729],
     [-0.2591, -0.0860, 0.4660],
     [-1.2189, -1.1206, 0.0637]],

    [[ 1.4791, -0.7569, 2.5017],
     [ 0.0098, -1.0217, 0.8142],
     [-0.2414, -0.1790, 2.3506],
     [-0.6860, -0.2363, 1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [-0.7363, 3.1847, -0.0860, -1.1206],
     [ 0.5534, -1.6729, 0.4660, 0.0637]],

    [[ 1.4791, 0.0098, -0.2414, -0.6860],
     [-0.7569, -1.0217, -0.1790, -0.2363],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363, 0.5534],
     [ 1.4791, -0.7569, 2.5017]],

    [[-0.2050, 3.1847, -1.6729],
     [ 0.0098, -1.0217, 0.8142]],

    [[-0.2591, -0.0860, 0.4660],
     [-0.2414, -0.1790, 2.3506]],

    [[-1.2189, -1.1206, 0.0637],
     [-0.6860, -0.2363, 1.0481]]])

可以对多维度的张量进行转置

  • torch.Tensor.permute(dims)

将 tensor 的维度换位

接着用上面的数据举个例子:

>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [ 1.4791, 0.0098, -0.2414, -0.6860]],

    [[-0.7363, 3.1847, -0.0860, -1.1206],
     [-0.7569, -1.0217, -0.1790, -0.2363]],

    [[ 0.5534, -1.6729, 0.4660, 0.0637],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])

直接在方法中填入各个维度的索引,张量就会交换指定维度的尺寸,不限于两两交换。

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现代码统计工具(终极篇)
Jul 04 Python
浅谈Django自定义模板标签template_tags的用处
Dec 20 Python
利用 python 对目录下的文件进行过滤删除
Dec 27 Python
示例详解Python3 or Python2 两者之间的差异
Aug 23 Python
python解析json串与正则匹配对比方法
Dec 20 Python
在python里从协程返回一个值的示例
Feb 19 Python
Django如何防止定时任务并发浅析
May 14 Python
图文详解python安装Scrapy框架步骤
May 20 Python
Python新手学习函数默认参数设置
Jun 03 Python
Django contrib auth authenticate函数源码解析
Nov 12 Python
一篇文章教你用python画动态爱心表白
Nov 22 Python
Python加密技术之RSA加密解密的实现
Apr 08 Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
You might like
使用PHP求两个文件的相对路径
2013/06/20 PHP
php cURL和Rolling cURL并发方式比较
2013/10/30 PHP
PHP调用JAVA的WebService简单实例
2014/03/11 PHP
php轻松实现中英文混排字符串截取
2014/05/28 PHP
php中Socket创建与监听实现方法
2015/01/05 PHP
PHP获取毫秒级时间戳的方法
2015/04/15 PHP
PHP 使用位运算实现四则运算的代码
2021/03/09 PHP
Javascript hasOwnProperty 方法 & in 关键字
2008/11/26 Javascript
json原理分析及实例介绍
2012/11/29 Javascript
js实现图片旋转的三种方法
2014/04/10 Javascript
javascript中Number对象的toString()方法分析
2014/12/20 Javascript
JavaScript中的函数模式详解
2015/02/11 Javascript
Jquery实现遮罩层的方法
2015/06/08 Javascript
jQuery模拟实现天猫购物车动画效果实例代码
2017/05/25 jQuery
微信小程序 密码输入(源码下载)
2017/06/27 Javascript
微信小程序 五星评分的实现实例
2017/08/04 Javascript
基于百度地图api清除指定覆盖物(Overlay)的方法
2018/01/26 Javascript
vue数据传递--我有特殊的实现技巧
2018/03/20 Javascript
vue.js学习笔记之v-bind和v-on解析
2018/05/03 Javascript
puppeteer实现html截图的示例代码
2019/01/10 Javascript
js计算两个时间差 天 时 分 秒 毫秒的代码
2019/05/21 Javascript
世界上最短的数字判断js代码
2019/09/09 Javascript
python Crypto模块的安装与使用方法
2017/12/21 Python
详解django.contirb.auth-认证
2018/07/16 Python
Python实现微信自动好友验证,自动回复,发送群聊链接方法
2019/02/21 Python
Python2.x与3​​.x版本有哪些区别
2020/07/09 Python
a标签下载链接的简单实现
2016/09/13 HTML / CSS
DKNY品牌官网:纽约大都会时尚风格
2016/10/20 全球购物
PHP如何去执行一个SQL语句
2016/03/05 面试题
医学检验专业大学生求职信
2013/11/18 职场文书
前台领班岗位职责
2013/12/04 职场文书
学校安全生产承诺书
2014/05/23 职场文书
销售员试用期自我评价
2014/09/15 职场文书
三好学生评选事迹材料(2016精选版)
2016/02/25 职场文书
CSS3 天气图标动画效果
2021/04/06 HTML / CSS
Golang map映射的用法
2022/04/22 Golang