对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
用Python的线程来解决生产者消费问题的示例
Apr 02 Python
Python基础语法(Python基础知识点)
Feb 28 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 Python
python获取代码运行时间的实例代码
Jun 11 Python
利用ctypes获取numpy数组的指针方法
Feb 12 Python
Python中使用遍历在列表中添加字典遇到的坑
Feb 27 Python
python实现微信防撤回神器
Apr 29 Python
Python3.6+Django2.0以上 xadmin站点的配置和使用教程图解
Jun 04 Python
python UDP(udp)协议发送和接收的实例
Jul 22 Python
简单介绍python封装的基本知识
Aug 10 Python
python实现扫雷游戏的示例
Oct 20 Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
YII路径的用法总结
2014/07/09 PHP
php实现中文字符截取防乱码方法汇总
2015/04/29 PHP
Yii中srbac权限扩展模块工作原理与用法分析
2016/07/14 PHP
JavaScript 无符号右移赋值操作
2009/04/17 Javascript
JavaScript的eval JSON object问题
2009/11/15 Javascript
JQuery 引发两次$(document.ready)事件
2010/01/15 Javascript
Ext.get() 和 Ext.query()组合使用实现最灵活的取元素方式
2011/09/26 Javascript
json数据处理技巧(字段带空格、增加字段、排序等等)
2013/06/14 Javascript
简单实用jquery版三级联动select示例
2013/07/04 Javascript
【经典源码收藏】jQuery实用代码片段(筛选,搜索,样式,清除默认值,多选等)
2016/06/07 Javascript
ES6之模版字符串的具体使用
2018/05/17 Javascript
关于layui时间回显问题的解决方法
2019/09/24 Javascript
[02:25]DOTA2英雄基础教程 虚空假面
2014/01/02 DOTA
Python实现在Linux系统下更改当前进程运行用户
2015/02/04 Python
用Python的线程来解决生产者消费问题的示例
2015/04/02 Python
利用Pandas 创建空的DataFrame方法
2018/04/08 Python
Python根据文件名批量转移图片的方法
2018/10/21 Python
Flask框架web开发之零基础入门
2018/12/10 Python
python实现Dijkstra静态寻路算法
2019/01/17 Python
python替换字符串中的子串图文步骤
2019/06/19 Python
python利用itertools生成密码字典并多线程撞库破解rar密码
2019/08/12 Python
Python如何基于rsa模块实现非对称加密与解密
2020/01/03 Python
pytorch实现Tensor变量之间的转换
2020/02/17 Python
Python3自动生成MySQL数据字典的markdown文本的实现
2020/05/07 Python
新秀丽拉杆箱美国官方网站:Samsonite美国
2016/07/25 全球购物
LivingSocial爱尔兰:爱尔兰本地优惠
2018/08/10 全球购物
Linux如何命名文件--使用文件名时应注意
2014/05/29 面试题
大学生毕业自我鉴定范文
2013/11/03 职场文书
个人求职信范文分享
2014/01/31 职场文书
幼儿园招生广告
2014/03/19 职场文书
个性婚礼策划方案
2014/05/17 职场文书
疾病防治方案
2014/05/31 职场文书
爱护公共设施倡议书
2014/08/29 职场文书
离婚被告代理词
2015/05/23 职场文书
幼儿园开学家长寄语(2015秋季)
2015/05/27 职场文书
Python测试框架pytest核心库pluggy详解
2022/08/05 Golang