pytorch下的unsqueeze和squeeze的用法说明


Posted in Python onFebruary 06, 2021

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

import torch 
a=torch.rand(2,3,1)       
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) 
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3]) 
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
 
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
 
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()

torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。

torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。

a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。

a=torch.randn(1,1,3,768)
print(a) 
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接

a=torch.randn(4,3)
b=torch.randn(4,3)
 
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。

tensors=[]
for i in range(10):
  tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:

torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python数据结构之二叉树的统计与转换实例
Apr 29 Python
Python实现Linux下守护进程的编写方法
Aug 22 Python
tensorflow入门之训练简单的神经网络方法
Feb 26 Python
对Python 两大环境管理神器 pyenv 和 virtualenv详解
Dec 31 Python
selenium 安装与chromedriver安装的方法步骤
Jun 12 Python
对pyqt5之menu和action的使用详解
Jun 20 Python
Django之使用celery和NGINX生成静态页面实现性能优化
Oct 08 Python
使用tqdm显示Python代码执行进度功能
Dec 08 Python
Python多线程Threading、子线程与守护线程实例详解
Mar 24 Python
详解Python高阶函数
Aug 15 Python
关于pycharm 切换 python3.9 报错 ‘HTMLParser‘ object has no attribute ‘unescape‘ 的问题
Nov 24 Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 #Python
解决pycharm不能自动保存在远程linux中的问题
Feb 06 #Python
Python第三方库安装缓慢的解决方法
Feb 06 #Python
python中threading和queue库实现多线程编程
Feb 06 #Python
Python3爬虫ChromeDriver的安装实例
Feb 06 #Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 #Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 #Python
You might like
php采集神器cURL使用方法详解
2016/02/19 PHP
一次因composer错误使用引发的问题与解决
2019/03/06 PHP
javascript生成/解析dom的CDATA类型的字段的代码
2007/04/22 Javascript
JavaScript入门教程(6) Window窗口对象
2009/01/31 Javascript
extJs 常用到的增,删,改,查操作代码
2009/12/28 Javascript
firefox浏览器不支持innerText的解决方法
2013/08/07 Javascript
jquery库或JS文件在eclipse下报错问题解决方法
2014/04/17 Javascript
jquery进行数组遍历如何跳出当前的each循环
2014/06/05 Javascript
超链接的禁用属性Disabled使用示例
2014/07/31 Javascript
js跨域请求的5中解决方式
2015/07/02 Javascript
jQuery实现元素拖拽并cookie保存顺序的方法
2016/02/20 Javascript
基于jquery实现三级下拉菜单
2016/05/10 Javascript
深入理解js generator数据类型
2016/08/16 Javascript
javascript简易画板开发
2020/04/12 Javascript
获取jqGrid中选择的行的数据
2016/11/30 Javascript
angularjs点击图片放大实现上传图片预览
2017/02/24 Javascript
Angular.js组件之input mask对input输入进行格式化详解
2017/07/10 Javascript
微信小程序模板和模块化用法实例分析
2017/11/28 Javascript
基于mpvue的小程序项目搭建的步骤
2018/05/22 Javascript
JavaScript arguments.callee作用及替换方案详解
2020/09/02 Javascript
python获取本地计算机名字的方法
2015/04/29 Python
python文件名和文件路径操作实例
2017/09/29 Python
python 日志增量抓取实现方法
2018/04/28 Python
Python实现针对json中某个关键字段进行排序操作示例
2018/12/25 Python
python针对Oracle常见查询操作实例分析
2020/04/30 Python
解决Django Haystack全文检索为空的问题
2020/05/19 Python
python 使用cycle构造无限循环迭代器
2020/12/02 Python
Michael Kors香港官网:美国奢侈品品牌
2019/12/26 全球购物
大专毕业自我鉴定
2014/02/04 职场文书
网页美工求职信
2014/02/15 职场文书
《故都的秋》教学反思
2014/04/15 职场文书
伊索寓言教学反思
2014/05/01 职场文书
乡镇党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
公路施工安全责任书
2015/05/08 职场文书
八年级数学教学反思
2016/02/17 职场文书
PostgreSQL逻辑复制解密原理解析
2022/09/23 PostgreSQL