对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
详解Python中time()方法的使用的教程
May 22 Python
深入学习python的yield和generator
Mar 10 Python
谈谈如何手动释放Python的内存
Dec 17 Python
Django + Uwsgi + Nginx 实现生产环境部署的方法
Jun 20 Python
python 递归深度优先搜索与广度优先搜索算法模拟实现
Oct 22 Python
python3使用pandas获取股票数据的方法
Dec 22 Python
python爬取内容存入Excel实例
Feb 20 Python
Python图像处理之图片文字识别功能(OCR)
Jul 30 Python
Python中的特殊方法以及应用详解
Sep 20 Python
python爬虫爬取淘宝商品比价(附淘宝反爬虫机制解决小办法)
Dec 03 Python
正确的理解和使用Django信号(Signals)
Apr 14 Python
pandas 实现将NaN转换为None
May 14 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
jQuery获取json后使用zy_tmpl生成下拉菜单
2015/03/27 PHP
PHP弱类型的安全问题详细总结
2016/09/25 PHP
PHP convert_uudecode()函数讲解
2019/02/14 PHP
利用404错误页面实现UrlRewrite的实现代码
2008/08/20 Javascript
JavaScript(JS) 压缩 / 混淆 / 格式化 批处理工具
2010/12/10 Javascript
解析javascript系统错误:-1072896658的解决办法
2013/07/08 Javascript
AngularJS入门知识之MVW类框架的编程思想探讨
2014/12/08 Javascript
JS+CSS实现带关闭按钮DIV弹出窗口的方法
2015/02/27 Javascript
基于javascript实现最简单的选项卡切换效果
2016/05/16 Javascript
vue-router:嵌套路由的使用方法
2017/02/21 Javascript
react实现一个优雅的图片占位模块组件详解
2017/10/30 Javascript
小程序如何构建骨架屏
2019/05/29 Javascript
一次微信小程序内地图的使用实战记录
2019/09/09 Javascript
jQuery实现的解析本地 XML 文档操作示例
2020/04/30 jQuery
微信小程序实现文件预览
2020/10/22 Javascript
微信小程序实现点赞业务
2021/02/10 Javascript
Python中使用Tkinter模块创建GUI程序实例
2015/01/14 Python
Python中使用PIL库实现图片高斯模糊实例
2015/02/08 Python
window下eclipse安装python插件教程
2017/04/24 Python
Python将list中的string批量转化成int/float的方法
2018/06/26 Python
python异步编程 使用yield from过程解析
2019/09/25 Python
借助Paramiko通过Python实现linux远程登陆及sftp的操作
2020/03/16 Python
当当网官方旗舰店:中国图书销售夺金品牌
2018/04/02 全球购物
求职推荐信
2013/10/28 职场文书
大学毕业生文采飞扬的自我鉴定
2013/12/03 职场文书
小学防溺水制度
2014/01/29 职场文书
优秀毕业生自我鉴定
2014/02/11 职场文书
行政工作个人的自我评价
2014/02/13 职场文书
英语专业求职信
2014/07/08 职场文书
励志演讲稿200字
2014/08/21 职场文书
公共艺术专业自荐信
2014/09/01 职场文书
阳光体育运动标语口号
2015/12/26 职场文书
MySQL Shell的介绍以及安装
2021/04/24 MySQL
关于vue中如何监听数组变化
2021/04/28 Vue.js
Python学习之包与模块详解
2022/03/19 Python
SQL Server中使用表变量和临时表
2022/05/20 SQL Server