PyTorch中Tensor的拼接与拆分的实现


Posted in Python onAugust 18, 2019

拼接张量:torch.cat() 、torch.stack()

  1. torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列 seq 进行连接操作

举个例子:

>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 0) # 在 0 维(纵向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 1) # 在 1 维(横向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039, -0.1997, -0.6900, 0.7039, -0.1997, -0.6900,
     0.7039],
    [ 0.0268, -1.0140, -2.9764, 0.0268, -1.0140, -2.9764, 0.0268, -1.0140,
     -2.9764]])
>>> y1 = torch.randn(5, 3, 6)
>>> y2 = torch.randn(5, 3, 6)
>>> torch.cat([y1, y2], 2).size()
torch.Size([5, 3, 12])
>>> torch.cat([y1, y2], 1).size()
torch.Size([5, 6, 6])

对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。

  • torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状

举个例子:

>>> x1 = torch.randn(2, 3)
>>> x2 = torch.randn(2, 3)
>>> torch.stack((x1, x2), 0).size() # 在 0 维插入一个维度,进行区分拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 1).size() # 在 1 维插入一个维度,进行组合拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 2).size()
torch.Size([2, 3, 2])
>>> torch.stack((x1, x2), 0)
tensor([[[-0.3499, -0.6124, 1.4332],
     [ 0.1516, -1.5439, -0.1758]],

    [[-0.4678, -1.1430, -0.5279],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 1)
tensor([[[-0.3499, -0.6124, 1.4332],
     [-0.4678, -1.1430, -0.5279]],

    [[ 0.1516, -1.5439, -0.1758],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 2)
tensor([[[-0.3499, -0.4678],
     [-0.6124, -1.1430],
     [ 1.4332, -0.5279]],

    [[ 0.1516, -0.4917],
     [-1.5439, -0.6504],
     [-0.1758, 2.2512]]])

把相同形状的张量合并,并根据提供的维度序列在相应位置插入维度,方法会根据位置来排列数据。代码中,根据第 0 维和第 1 维来进行合并时,虽然合并后的张量维度和尺寸相等,但是数据的位置并不是相同的。

拆分张量:torch.split()、torch.chunk()

  • torch.split(tensor, split_size, dim=0)

将输入张量分割成相等形状的 chunks(如果可分)。 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。

举个例子:

>>> x = torch.randn(3, 10, 6)
>>> a, b, c = x.split(1, 0) # 在 0 维进行间隔维 1 的拆分
>>> a.size(), b.size(), c.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> d, e = x.split(2, 0) # 在 0 维进行间隔维 2 的拆分
>>> d.size(), e.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上以间隔 1 来拆分时,其中 x 在 0 维度上的尺寸为 3,就可以分成 3 份。

把张量在 0 维度上以间隔 2 来拆分时,只能分成 2 份,且只能把前面部分先以间隔 2 来拆分,后面不足 2 的部分就直接作为一个分块。

  • torch.chunk(tensor, chunks, dim=0)

在给定维度(轴)上将输入张量进行分块儿

直接用上面的数据来举个例子:

>>> l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份
>>> l.size(), m.size(), n.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份
>>> u.size(), v.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上拆分成 3 部分时,因为尺寸正好为 3,所以每个分块的间隔相等,都为 1。

把张量在 0 维度上拆分成 2 部分时,无法平均分配,以上面的结果来看,可以看成是,用 0 维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。

在某一维度上拆分的份数不能比这一维度的尺寸大

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python数组条件过滤filter函数使用示例
Jul 22 Python
CentOS 6.X系统下升级Python2.6到Python2.7 的方法
Oct 12 Python
Python安装pycurl失败的解决方法
Oct 15 Python
Python3.4学习笔记之常用操作符,条件分支和循环用法示例
Mar 01 Python
python实现最小二乘法线性拟合
Jul 19 Python
Django项目基础配置和基本使用过程解析
Nov 25 Python
python打包生成so文件的实现
Oct 30 Python
python爬虫scrapy图书分类实例讲解
Nov 23 Python
Python 实现RSA加解密文本文件
Dec 30 Python
python程序实现BTC(比特币)挖矿的完整代码
Jan 20 Python
python如何正确使用yield
May 21 Python
Jupyter Notebook 如何修改字体和大小以及更改字体样式
Jun 03 Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
You might like
咖啡与牛奶
2021/03/03 冲泡冲煮
PHP源码之explode使用说明
2011/08/05 PHP
PHP缓存技术的多种方法小结
2012/08/14 PHP
分享下页面关键字抓取www.icbase.com站点代码(带asp.net参数的)
2014/01/30 PHP
浅谈PHP中单引号和双引号到底有啥区别呢?
2015/03/04 PHP
php strftime函数的详细用法
2018/06/21 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
js以对象为索引的关联数组
2010/07/04 Javascript
js 编程笔记 无名函数
2011/06/28 Javascript
JS重要知识点小结
2011/11/06 Javascript
AngularJS 如何在控制台进行错误调试
2016/06/07 Javascript
js中获取jsp表单中radio类型的值简单实例
2016/08/15 Javascript
Angularjs中controller的三种写法分享
2016/09/21 Javascript
微信小程序上滑加载下拉刷新(onscrollLower)分批加载数据(一)
2017/05/11 Javascript
Vue.js实现分页查询功能
2020/11/15 Javascript
JavaScript简单实现关键字文本搜索高亮显示功能示例
2018/07/25 Javascript
elementUI 动态生成几行几列的方法示例
2019/07/11 Javascript
countup.js实现数字动态叠加效果
2019/10/17 Javascript
使用JS监听键盘按下事件(keydown event)
2019/11/07 Javascript
python实现通过shelve修改对象实例
2014/09/26 Python
Python获取系统默认字符编码的方法
2015/06/04 Python
python使用参数对嵌套字典进行取值的方法
2019/04/26 Python
Python Opencv提取图片中某种颜色组成的图形的方法
2019/09/19 Python
在Tensorflow中查看权重的实现
2020/01/24 Python
python 日志模块 日志等级设置失效的解决方案
2020/05/26 Python
如何实现jdbc性能优化
2012/07/30 面试题
木工主管岗位职责
2013/12/08 职场文书
手术室护士长竞聘书
2014/03/31 职场文书
遗嘱继承公证书
2014/04/09 职场文书
房产转让协议书
2014/04/11 职场文书
优秀党务工作者事迹材料
2014/05/07 职场文书
2015年学生会个人工作总结
2015/04/09 职场文书
2015年小学二年级班主任工作总结
2015/05/21 职场文书
2015年小学语文教师工作总结
2015/10/23 职场文书
2016简单的租房合同范本
2016/03/18 职场文书
创业计划书之冷饮店
2019/09/27 职场文书