对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python使用sorted函数对列表进行排序的方法
Apr 04 Python
高效测试用例组织算法pairwise之Python实现方法
Jul 19 Python
python的Tqdm模块的使用
Jan 10 Python
python之pandas用法大全
Mar 13 Python
python flask框架实现重定向功能示例
Jul 02 Python
python从list列表中选出一个数和其对应的坐标方法
Jul 20 Python
用openCV和Python 实现图片对比,并标识出不同点的方式
Dec 19 Python
解决TensorFlow模型恢复报错的问题
Feb 06 Python
Python反爬虫伪装浏览器进行爬虫
Feb 28 Python
Xadmin+rules实现多选行权限方式(级联效果)
Apr 07 Python
Eclipse配置python默认头过程图解
Apr 26 Python
Python如何使用队列方式实现多线程爬虫
May 12 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
虹吸式咖啡壶操作
2021/03/03 冲泡冲煮
Apache中php.ini的设置方法
2013/02/28 PHP
JavaScript 学习笔记(十一)
2010/01/19 Javascript
禁用Tab键JS代码兼容Firefox和IE
2014/04/18 Javascript
Jquery解析Json格式数据过程代码
2014/10/17 Javascript
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
Babel 入门教程学习笔记
2018/06/13 Javascript
深入浅出理解JavaScript高级定时器原理与用法
2018/08/02 Javascript
基于javascript的无缝滚动动画实现2
2020/08/07 Javascript
[09:31]2016国际邀请赛中国区预选赛Yao赛后采访 答题送礼
2016/06/27 DOTA
[01:01:24]LGD vs Fnatic 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
一个计算身份证号码校验位的Python小程序
2014/08/15 Python
Python使用chardet判断字符编码
2015/05/09 Python
python实现解数独程序代码
2017/04/12 Python
python urllib爬取百度云连接的实例代码
2017/06/19 Python
Python数据结构之双向链表的定义与使用方法示例
2018/01/16 Python
Python爬虫爬取一个网页上的图片地址实例代码
2018/01/16 Python
python监控进程脚本
2018/04/12 Python
python读取tif图片时保留其16bit的编码格式实例
2020/01/13 Python
详解移动端Html5页面中1px边框的几种解决方法
2018/07/24 HTML / CSS
英国婚礼商城:Wedding Mall
2019/11/02 全球购物
应届生求职信写作技巧
2013/10/24 职场文书
写好自荐信的要点
2013/11/06 职场文书
学习十八大报告感言
2014/02/04 职场文书
借款协议书
2014/04/12 职场文书
2014年行政人事工作总结
2014/12/09 职场文书
超市员工辞职信范文
2015/05/12 职场文书
学生早退检讨书(范文)
2019/08/19 职场文书
遇事可以测出您的见识与格局
2019/09/16 职场文书
基于Python实现的购物商城管理系统
2021/04/27 Python
比较几种Redis集群方案
2021/06/21 Redis
html输入两个数实现加减乘除功能
2021/07/01 HTML / CSS
Javascript使用integrity属性进行安全验证
2021/11/07 Javascript
Netty分布式客户端接入流程初始化源码分析
2022/03/25 Java/Android
tomcat的catalina.out日志按自定义时间格式进行分割的操作方法
2022/04/02 Servers
MySql统计函数COUNT的具体使用详解
2022/08/14 MySQL