对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python获取apk文件URL地址实例
Nov 01 Python
python验证码识别教程之滑动验证码
Jun 04 Python
python之mock模块基本使用方法详解
Jun 27 Python
python变量的存储原理详解
Jul 10 Python
Python空间数据处理之GDAL读写遥感图像
Aug 01 Python
Django中ajax发送post请求 报403错误CSRF验证失败解决方案
Aug 13 Python
pandas 缺失值与空值处理的实现方法
Oct 12 Python
基于Python+Appium实现京东双十一自动领金币功能
Oct 31 Python
tensorflow之tf.record实现存浮点数数组
Feb 17 Python
Python的PIL库中getpixel方法的使用
Apr 09 Python
分享一个python的aes加密代码
Dec 22 Python
python wsgiref源码解析
Feb 06 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
PHP输出缓存ob系列函数详解
2014/03/11 PHP
ThinkPHP使用心得分享-分页类Page的用法
2014/05/15 PHP
ThinkPHP模板中数组循环实例
2014/10/30 PHP
PHP中strcmp()和strcasecmp()函数字符串比较用法分析
2016/01/07 PHP
详解PHP处理字符串类似indexof的方法函数
2017/06/11 PHP
jquery实现文本框鼠标右击无效以及不能输入的代码
2010/11/05 Javascript
jQuery1.4.2与老版本json格式兼容的解决方法
2011/02/12 Javascript
(跨浏览器基础事件/浏览器检测/判断浏览器)经验代码分享
2013/01/24 Javascript
jquery实现漂浮在网页右侧的qq在线客服插件示例
2013/05/13 Javascript
DOM基础教程之模型中的模型节点
2015/01/19 Javascript
JS实现左右无缝轮播图代码
2016/05/01 Javascript
基于JavaScript实现跳转提示页面
2016/09/24 Javascript
Vue.js使用v-show和v-if的注意事项
2016/12/13 Javascript
JavaScript 网页中实现一个计算当年还剩多少时间的倒数计时程序
2017/01/25 Javascript
详解vue-cli项目中的proxyTable跨域问题小结
2018/02/09 Javascript
Bootstrap4如何定制自己的颜色和风格
2018/02/26 Javascript
详解JS判断页面是在手机端还是在PC端打开的方法
2019/04/26 Javascript
[01:25]2014DOTA2国际邀请赛 zhou分析LGD比赛情况
2014/07/14 DOTA
跟老齐学Python之编写类之三子类
2014/10/11 Python
python匹配两个短语之间的字符实例
2018/12/25 Python
记一次django内存异常排查及解决方法
2020/08/07 Python
HTML5网页音乐播放器的示例代码
2017/11/09 HTML / CSS
HTML实现代码雨源码及效果示例
2020/02/25 HTML / CSS
日本动漫周边服饰销售网站:Atsuko
2019/12/16 全球购物
全球精选男装和家居用品:Article
2020/04/13 全球购物
大学生个人总结的自我评价
2013/10/05 职场文书
英文版区域经理求职信
2013/10/23 职场文书
高三地理教学反思
2014/01/11 职场文书
国庆节标语大全
2014/10/08 职场文书
2014年环保局工作总结
2014/12/11 职场文书
升学宴学生致辞
2015/09/29 职场文书
银行求职信怎么写
2019/06/20 职场文书
话题作文之诚信
2019/11/28 职场文书
使用pycharm运行flask应用程序的详细教程
2021/06/07 Python
vue项目如何打包之项目打包优化(让打包的js文件变小)
2022/04/30 Vue.js
MySQL中LAG()函数和LEAD()函数的使用
2022/08/14 MySQL