Pytorch Tensor的索引与切片例子


Posted in Python onAugust 18, 2019

1. Pytorch风格的索引

根据Tensor的shape,从前往后索引,依次在每个维度上做索引。

示例代码:

import torch
 
a = torch.rand(4, 3, 28, 28)
print(a[0].shape) #取到第一个维度
print(a[0, 0].shape) # 取到二个维度
print(a[1, 2, 2, 4]) # 具体到某个元素

上述代码创建了一个shape=[4, 3, 28, 28]的Tensor,我们可以理解为4张图片,每张图片有3个通道,每个通道是28x28的图像数据。a代表这个Tensor,a后面跟着的列表[]表示对Tensor进行索引,a的维度dim = 4,决定了[]中的元素个数不能超过4个,[]中的值表示对应维度上的哪一个元素,比如 a[0]表示取第一个维度上的第一个元素,可以理解为第一张图片,a[1]表示取第一个维度上的第二个元素,可以理解为第二张图片。a[0, 0]表示取第一个维度上第一个元素的与第二个维度上的第一个元素,也就是第一张图片第一个通道的元素。a[1, 2, 2, 4]表示取第第一个维度上的第二个元素与第二个维度上的第三个元素与第三个维度上的第三个元素与第四个维度上的第5个元素,也就是第二张图片第三个通道第三行第四列的像素值是一个标量值。

输出结果:

torch.Size([3, 28, 28])
torch.Size([28, 28])
tensor(0.1076)

2. python风格的索引

示例代码:

import torch
 
# 譬如:4张图片,每张三个通道,每个通道28行28列的像素
a = torch.rand(4, 3, 28, 28)
 
# 在第一个维度上取后0和1,等同于取第一、第二张图片
print(a[:2].shape) 
 
# 在第一个维度上取0和1,在第二个维度上取0,
# 等同于取第一、第二张图片中的第一个通道
print(a[:2, :1, :, :].shape) 
 
# 在第一个维度上取0和1,在第二个维度上取1,2,
# 等同于取第一、第二张图片中的第二个通道与第三个通道
print(a[:2, 1:, :, :].shape) 
 
# 在第一个维度上取0和1,在第二个维度上取1,2,
# 等同于取第一、第二张图片中的第二个通道与第三个通道
print(a[:2, -2:, :, :].shape) 
 
# 使用step隔行采样
# 在第一、第二维度取所有元素,在第三、第四维度隔行采样
# 等同于所有图片所有通道的行列每个一行或者一列采样
# 注意:下面的代码不包括28
print(a[:, :, 0:28:2, 0:28:2].shape) 
print(a[:, :, ::2, ::2].shape) # 等同于上面语句

注意:负值的索引即表示倒数第几个元素,-2就是倒数第二个元素。

输出结果:

torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 2, 28, 28])

3. index_select()选择特定索引

选择特定下标有时候很有用,比如上面的a这个Tensor可以看作4张RGB(3通道)的MNIST图像,长宽都是28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道,在第三维度上选择特定的行等:

# 选择第一张和第三张图
print(a.index_select(0, torch.tensor([0, 2])).shape)
 
# 选择R通道和B通道
print(a.index_select(1, torch.tensor([0, 2])).shape)
 
# 选择图像的0~8行
print(a.index_select(2, torch.arange(8)).shape)

注意:index_select()的第二个索引参数必须是Tensor类型

输出结果:

torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])
torch.Size([4, 3, 8, 28])

4. 使用 ... 索引任意多的维度

import torch
 
a = torch.rand(4, 3, 28, 28)
 
# 等与a
print(a[...].shape)
 
# 第一张图片的所有维度
print(a[0, ...].shape)
 
# 所有图片第二通道的所有维度
print(a[:, 1, ...].shape)
 
# 所有图像所有通道所有行的第一、第二列
print(a[..., :2].shape)

输出结果:

torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 3, 28, 2])

5. 使用mask索引

示例代码:

import torch
 
a = torch.randn(3, 4)
print(a)
 
# 生成a这个Tensor中大于0.5的元素的掩码
mask = a.ge(0.5)
print(mask)
 
# 取出a这个Tensor中大于0.5的元素
val = torch.masked_select(a, mask)
print(val)
print(val.shape)

输出结果:

tensor([[ 0.2055, -0.7070, 1.1201, 1.3325],
    [-1.6459, 0.9635, -0.2741, 0.0765],
    [ 0.2943, 0.1206, 1.6662, 1.5721]])
tensor([[0, 0, 1, 1],
    [0, 1, 0, 0],
    [0, 0, 1, 1]], dtype=torch.uint8)
tensor([1.1201, 1.3325, 0.9635, 1.6662, 1.5721])
torch.Size([5])

注意:最后取出的 大于0.5的Tensor的shape已经被打平。

6. take索引

take索引是在原来Tensor的shape基础上打平,然后在打平后的Tensor上进行索引。

示例代码:

import torch
 
a = torch.tensor([[3, 7, 2], [2, 8, 3]])
print(a)
print(torch.take(a, torch.tensor([0, 1, 5])))

输出结果:

tensor([[3, 7, 2],
    [2, 8, 3]])
tensor([3, 7, 3])

以上这篇Pytorch Tensor的索引与切片例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python监控网卡流量并使用graphite绘图的示例
Apr 27 Python
收藏整理的一些Python常用方法和技巧
May 18 Python
初步讲解Python中的元组概念
May 21 Python
Python+django实现简单的文件上传
Aug 17 Python
python实现协同过滤推荐算法完整代码示例
Dec 15 Python
python Crypto模块的安装与使用方法
Dec 21 Python
Python BS4库的安装与使用详解
Aug 08 Python
pygame游戏之旅 创建游戏窗口界面
Nov 20 Python
基于Django的乐观锁与悲观锁解决订单并发问题详解
Jul 31 Python
pytorch中的自定义反向传播,求导实例
Jan 06 Python
PyTorch笔记之scatter()函数的使用
Feb 12 Python
python数据分析之用sklearn预测糖尿病
Apr 22 Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
You might like
用PHP和ACCESS写聊天室(三)
2006/10/09 PHP
PHP 用数组降低程序的时间复杂度
2009/12/04 PHP
php实现微信公众平台账号自定义菜单类
2015/10/11 PHP
PHP中关键字interface和implements详解
2017/06/14 PHP
PHP自定义函数判断是否为Get、Post及Ajax提交的方法
2017/07/27 PHP
PHP _construct()函数讲解
2019/02/03 PHP
jquery.messager.js插件导致页面抖动的解决方法
2013/07/14 Javascript
跨域传值即主页面与iframe之间互相传值
2013/12/09 Javascript
jquery如何获取复选框的值
2013/12/12 Javascript
JavaScript中的异常捕捉介绍
2014/12/31 Javascript
详解JavaScript中循环控制语句的用法
2015/06/03 Javascript
js实现三张图(文)片一起切换的banner焦点图
2015/08/25 Javascript
JavaScript实现页面跳转的方式汇总
2016/05/16 Javascript
JS不用正则验证输入的字符串是否为空(包含空格)的实现代码
2016/06/14 Javascript
NodeJS中的MongoDB快速入门详细教程
2016/11/11 NodeJs
微信小程序 tabs选项卡效果的实现
2017/01/05 Javascript
vue.js,ajax渲染页面的实例
2018/02/11 Javascript
使用JS代码实现俄罗斯方块游戏
2018/08/03 Javascript
elementUI多选框反选的实现代码
2019/04/03 Javascript
Angular2使用SVG自定义图表(条形图、折线图)组件示例
2019/05/10 Javascript
Vue 实现html中根据类型显示内容
2019/10/28 Javascript
零基础写python爬虫之抓取百度贴吧代码分享
2014/11/06 Python
通过数据库向Django模型添加字段的示例
2015/07/21 Python
Python判断文本中消息重复次数的方法
2016/04/27 Python
Python基础教程之异常详解
2019/01/10 Python
基于Numba提高python运行效率过程解析
2020/03/02 Python
jupyter notebook中新建cell的方法与快捷键操作
2020/04/22 Python
雪花秀美国官方网站:韩国著名草本护肤化妆品品牌
2016/10/19 全球购物
与世界上最好的跑步专业品牌合作:Fleet Feet
2019/03/22 全球购物
香奈儿美国官网:CHANEL美国
2020/05/20 全球购物
3个CCIE对一个工程师的面试题
2012/05/06 面试题
Java语言程序设计测试题判断题部分
2013/01/06 面试题
会计专业毕业生自荐信范文
2013/12/20 职场文书
师德建设实施方案
2014/03/21 职场文书
重大事项社会稳定风险评估方案
2014/06/15 职场文书
SpringCloud超详细讲解Feign声明式服务调用
2022/06/21 Java/Android