pytorch下的unsqueeze和squeeze的用法说明


Posted in Python onFebruary 06, 2021

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

import torch 
a=torch.rand(2,3,1)       
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) 
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3]) 
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
 
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
 
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()

torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。

torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。

a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。

a=torch.randn(1,1,3,768)
print(a) 
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接

a=torch.randn(4,3)
b=torch.randn(4,3)
 
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。

tensors=[]
for i in range(10):
  tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:

torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
在Python编程过程中用单元测试法调试代码的介绍
Apr 02 Python
Python随手笔记第一篇(2)之初识列表和元组
Jan 23 Python
Python实现图片尺寸缩放脚本
Mar 10 Python
Django重装mysql后启动报错:No module named ‘MySQLdb’的解决方法
Apr 22 Python
Python3单行定义多个变量或赋值方法
Jul 12 Python
python实现信号时域统计特征提取代码
Feb 26 Python
Python字符串hashlib加密模块使用案例
Mar 10 Python
解决jupyter notebook打不开无反应 浏览器未启动的问题
Apr 10 Python
Python通过kerberos安全认证操作kafka方式
Jun 06 Python
Python wordcloud库安装方法总结
Dec 31 Python
python process模块的使用简介
May 14 Python
聊聊Python String型列表求最值的问题
Jan 18 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 #Python
解决pycharm不能自动保存在远程linux中的问题
Feb 06 #Python
Python第三方库安装缓慢的解决方法
Feb 06 #Python
python中threading和queue库实现多线程编程
Feb 06 #Python
Python3爬虫ChromeDriver的安装实例
Feb 06 #Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 #Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 #Python
You might like
PHP Google的translate API代码
2008/12/10 PHP
PHP中nowdoc和heredoc使用需要注意的一点
2014/03/21 PHP
php中get_object_vars()方法用法实例
2015/02/08 PHP
解决PHP上传非标准格式的图片pjpeg失败的方法
2017/03/12 PHP
Yii2框架实现登录、退出及自动登录功能的方法详解
2017/10/24 PHP
javascript innerHTML使用分析
2010/12/03 Javascript
jquery二级导航内容均分的原理及实现
2013/08/13 Javascript
js 3种归并操作的实例代码
2013/10/30 Javascript
nodejs下打包模块archiver详解
2014/12/03 NodeJs
jQuery实现无限往下滚动效果代码
2016/04/16 Javascript
JavaScript中return用法示例
2016/11/29 Javascript
JS组件系列之JS组件封装过程详解
2017/04/28 Javascript
vue.js使用代理和使用Nginx来解决跨域的问题
2018/02/03 Javascript
vue-cli初始化项目中使用less的方法
2018/08/09 Javascript
jquery实现动态改变css样式的方法分析
2019/05/27 jQuery
JS常见面试试题总结【去重、遍历、闭包、继承等】
2019/08/27 Javascript
layui问题之自动滚动二级iframe页面到指定位置的方法
2019/09/18 Javascript
vue子组件改变父组件传递的prop值通过sync实现数据双向绑定(DEMO)
2020/02/01 Javascript
js简单实现自动生成表格功能示例
2020/06/02 Javascript
Python深入学习之装饰器
2014/08/31 Python
python网络编程之数据传输UDP实例分析
2015/05/20 Python
Python实现简单字典树的方法
2016/04/29 Python
python多进程使用及线程池的使用方法代码详解
2018/10/24 Python
python生成并处理uuid的实现方式
2020/03/03 Python
python 实现任务管理清单案例
2020/04/25 Python
浅谈HTML5 FileReader分布读取文件以及其方法简介
2017/11/09 HTML / CSS
英国第二大营养品供应商:Vitabiotics
2016/10/01 全球购物
Java的五个基础面试题
2016/02/26 面试题
经典c++面试题六
2012/01/18 面试题
大学生自荐书范文
2013/12/10 职场文书
《雨霖铃》教学反思
2014/02/22 职场文书
读书伴我成长演讲稿
2014/05/07 职场文书
我的中国梦演讲稿500字
2014/08/19 职场文书
2016见义勇为事迹材料汇总
2016/03/01 职场文书
导游词之天津盘山
2019/11/01 职场文书
mybatis中注解与xml配置的对应关系和对比分析
2021/08/04 Java/Android