PyTorch中Tensor的拼接与拆分的实现


Posted in Python onAugust 18, 2019

拼接张量:torch.cat() 、torch.stack()

  1. torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列 seq 进行连接操作

举个例子:

>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 0) # 在 0 维(纵向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 1) # 在 1 维(横向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039, -0.1997, -0.6900, 0.7039, -0.1997, -0.6900,
     0.7039],
    [ 0.0268, -1.0140, -2.9764, 0.0268, -1.0140, -2.9764, 0.0268, -1.0140,
     -2.9764]])
>>> y1 = torch.randn(5, 3, 6)
>>> y2 = torch.randn(5, 3, 6)
>>> torch.cat([y1, y2], 2).size()
torch.Size([5, 3, 12])
>>> torch.cat([y1, y2], 1).size()
torch.Size([5, 6, 6])

对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。

  • torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状

举个例子:

>>> x1 = torch.randn(2, 3)
>>> x2 = torch.randn(2, 3)
>>> torch.stack((x1, x2), 0).size() # 在 0 维插入一个维度,进行区分拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 1).size() # 在 1 维插入一个维度,进行组合拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 2).size()
torch.Size([2, 3, 2])
>>> torch.stack((x1, x2), 0)
tensor([[[-0.3499, -0.6124, 1.4332],
     [ 0.1516, -1.5439, -0.1758]],

    [[-0.4678, -1.1430, -0.5279],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 1)
tensor([[[-0.3499, -0.6124, 1.4332],
     [-0.4678, -1.1430, -0.5279]],

    [[ 0.1516, -1.5439, -0.1758],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 2)
tensor([[[-0.3499, -0.4678],
     [-0.6124, -1.1430],
     [ 1.4332, -0.5279]],

    [[ 0.1516, -0.4917],
     [-1.5439, -0.6504],
     [-0.1758, 2.2512]]])

把相同形状的张量合并,并根据提供的维度序列在相应位置插入维度,方法会根据位置来排列数据。代码中,根据第 0 维和第 1 维来进行合并时,虽然合并后的张量维度和尺寸相等,但是数据的位置并不是相同的。

拆分张量:torch.split()、torch.chunk()

  • torch.split(tensor, split_size, dim=0)

将输入张量分割成相等形状的 chunks(如果可分)。 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。

举个例子:

>>> x = torch.randn(3, 10, 6)
>>> a, b, c = x.split(1, 0) # 在 0 维进行间隔维 1 的拆分
>>> a.size(), b.size(), c.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> d, e = x.split(2, 0) # 在 0 维进行间隔维 2 的拆分
>>> d.size(), e.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上以间隔 1 来拆分时,其中 x 在 0 维度上的尺寸为 3,就可以分成 3 份。

把张量在 0 维度上以间隔 2 来拆分时,只能分成 2 份,且只能把前面部分先以间隔 2 来拆分,后面不足 2 的部分就直接作为一个分块。

  • torch.chunk(tensor, chunks, dim=0)

在给定维度(轴)上将输入张量进行分块儿

直接用上面的数据来举个例子:

>>> l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份
>>> l.size(), m.size(), n.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份
>>> u.size(), v.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上拆分成 3 部分时,因为尺寸正好为 3,所以每个分块的间隔相等,都为 1。

把张量在 0 维度上拆分成 2 部分时,无法平均分配,以上面的结果来看,可以看成是,用 0 维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。

在某一维度上拆分的份数不能比这一维度的尺寸大

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的exec、eval使用实例
Sep 23 Python
跟老齐学Python之关于循环的小伎俩
Oct 02 Python
Python的collections模块中namedtuple结构使用示例
Jul 07 Python
Linux 修改Python命令的方法示例
Dec 03 Python
详解Python解决抓取内容乱码问题(decode和encode解码)
Mar 29 Python
python实现图片压缩代码实例
Aug 12 Python
解决pandas展示数据输出时列名不能对齐的问题
Nov 18 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
Jan 16 Python
Python爬虫制作翻译程序的示例代码
Feb 22 Python
Python实现DBSCAN聚类算法并样例测试
Jun 22 Python
一起来学习Python的元组和列表
Mar 13 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
You might like
php 生成静态页面的办法与实现代码详细版
2010/02/15 PHP
CentOS 安装 PHP5.5+Redis+XDebug+Nginx+MySQL全纪录
2015/03/25 PHP
CodeIgniter中使用Smarty3基本配置
2015/06/29 PHP
PHP MPDF中文乱码的解决方式
2015/12/08 PHP
完美解决在ThinkPHP控制器中命名空间的问题
2017/05/05 PHP
thinkPHP中U方法加密传递参数功能示例
2018/05/29 PHP
PHP的静态方法与普通方法用法实例分析
2019/09/26 PHP
IE不出现Flash激活框的小发现的js实现方法
2007/09/07 Javascript
面向对象的Javascript之三(封装和信息隐藏)
2012/01/27 Javascript
javascript获取web应用根目录的方法
2014/02/12 Javascript
jQuery判断div随滚动条滚动到一定位置后停止
2014/04/02 Javascript
js生成验证码并直接在前端判断
2015/05/15 Javascript
nodeJs内存泄漏问题详解
2016/09/05 NodeJs
jquery使用EasyUI Tree异步加载JSON数据(生成树)
2017/02/11 Javascript
JS文件/图片从电脑里面拖拽到浏览器上传文件/图片
2017/03/08 Javascript
Angular项目中$scope.$apply()方法的使用详解
2017/07/26 Javascript
vue单页缓存方案分析及实现
2018/09/25 Javascript
JAVA面试题 static关键字详解
2019/07/16 Javascript
JS函数参数的传递与同名参数实例分析
2020/03/16 Javascript
bootstrap实现tab选项卡切换
2020/08/09 Javascript
vue页面引入three.js实现3d动画场景操作
2020/08/10 Javascript
前端vue如何使用高德地图
2020/11/05 Javascript
深入了解Vue3模板编译原理
2020/11/19 Vue.js
python中json格式数据输出的简单实现方法
2016/10/31 Python
linux环境中没有网络怎么下载python
2019/07/07 Python
详解Anconda环境下载python包的教程(图形界面+命令行+pycharm安装)
2019/11/11 Python
用python拟合等角螺线的实现示例
2019/12/27 Python
selenium中get_cookies()和add_cookie()的用法详解
2020/01/06 Python
Python实现计算长方形面积(带参数函数demo)
2020/01/18 Python
Python读取表格类型文件代码实例
2020/02/17 Python
ktv收银员岗位职责
2013/12/16 职场文书
团工委书记自荐书范文
2013/12/17 职场文书
忠诚教育心得体会
2014/09/03 职场文书
医药公司开票员岗位职责
2015/04/15 职场文书
观看焦裕禄观后感
2015/06/09 职场文书
2015年秋季开学典礼校长致辞
2015/07/16 职场文书