pytorch下的unsqueeze和squeeze的用法说明


Posted in Python onFebruary 06, 2021

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

import torch 
a=torch.rand(2,3,1)       
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) 
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3]) 
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
 
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
 
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()

torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。

torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。

a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。

a=torch.randn(1,1,3,768)
print(a) 
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接

a=torch.randn(4,3)
b=torch.randn(4,3)
 
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。

tensors=[]
for i in range(10):
  tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:

torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python面向对象编程中的类和对象学习教程
Mar 30 Python
Python实现在线程里运行scrapy的方法
Apr 07 Python
Python中生成器和yield语句的用法详解
Apr 17 Python
Python中规范定义命名空间的一些建议
Jun 04 Python
python实现FTP服务器服务的方法
Apr 11 Python
Python发送http请求解析返回json的实例
Mar 26 Python
Selenium鼠标与键盘事件常用操作方法示例
Aug 13 Python
Python中的类与类型示例详解
Jul 10 Python
python标记语句块使用方法总结
Aug 05 Python
Pytorch中index_select() 函数的实现理解
Nov 19 Python
python tqdm实现进度条的示例代码
Nov 10 Python
教你如何使用Python下载B站视频的详细教程
Apr 29 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 #Python
解决pycharm不能自动保存在远程linux中的问题
Feb 06 #Python
Python第三方库安装缓慢的解决方法
Feb 06 #Python
python中threading和queue库实现多线程编程
Feb 06 #Python
Python3爬虫ChromeDriver的安装实例
Feb 06 #Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 #Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 #Python
You might like
造就帕卡马拉的帕卡斯是怎么被发现的
2021/03/03 咖啡文化
通过5个php实例细致说明传值与传引用的区别
2012/08/08 PHP
10条php编程小技巧
2015/07/07 PHP
php版微信自动登录并获取昵称的方法
2016/09/23 PHP
PHP仿微信发红包领红包效果
2016/10/30 PHP
from 表单提交返回值用post或者是get方法实现
2013/08/21 Javascript
jquery ajax应用中iframe自适应高度问题解决方法
2014/04/12 Javascript
jquery获取对象的方法足以应付常见的各种类型的对象
2014/05/14 Javascript
node.js中RPC(远程过程调用)的实现原理介绍
2014/12/05 Javascript
分享五个有用的jquery小技巧
2015/10/08 Javascript
手机端实现Bootstrap简单图片轮播效果
2016/10/13 Javascript
AngularJs用户登录问题处理(交互及验证、阻止FQ处理)
2017/10/26 Javascript
three.js实现3D影院的原理的代码分析
2017/12/18 Javascript
Vuex 入门教程
2018/01/10 Javascript
vue打包相关细节整理(小结)
2018/09/28 Javascript
Vue的路由及路由钩子函数的实现
2019/07/02 Javascript
用Node写一条配置环境的指令
2019/11/14 Javascript
ES6字符串的扩展实例
2020/12/21 Javascript
[00:53]2015国际邀请赛 中国区预选赛一触即发
2015/05/14 DOTA
python数据分析数据标准化及离散化详解
2018/02/26 Python
Python中文件的读取和写入操作
2018/04/27 Python
Pandas实现一列数据分隔为两列
2020/05/18 Python
python 高阶函数简单介绍
2021/02/19 Python
pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异
2021/02/25 Python
关于PySnooper 永远不要使用print进行调试的问题
2021/03/04 Python
英国奢侈皮具品牌:Aspinal of London
2018/09/02 全球购物
培训自我鉴定
2014/01/31 职场文书
市场营销专业大学生职业生涯规划文
2014/03/06 职场文书
交通事故私了协议书
2014/04/16 职场文书
500字小学生检讨书
2015/02/19 职场文书
考试没考好检讨书
2015/05/06 职场文书
《狼牙山五壮士》读后感:宁死不屈,视死如归
2019/08/16 职场文书
用Python selenium实现淘宝抢单机器人
2021/06/18 Python
SpringBoot实现quartz定时任务可视化管理功能
2021/08/30 Java/Android
Vue+TypeScript中处理computed方式
2022/04/02 Vue.js
MySQL串行化隔离级别(间隙锁实现)
2022/06/16 MySQL