对pytorch中x = x.view(x.size(0), -1) 的理解说明


Posted in Python onMarch 03, 2021

在pytorch的CNN代码中经常会看到

x.view(x.size(0), -1)

首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度

a = torch.Tensor(2,3)
print(a)
# tensor([[0.0000, 0.0000, 0.0000],
#    [0.0000, 0.0000, 0.0000]])
 
print(a.view(1,-1))
# tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

在CNN中卷积或者池化之后需要连接全连接层,所以需要把多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能

def forward(self,x):
  x=self.pre(x)
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  x=self.layer4(x)
    
  x=F.avg_pool2d(x,7)
  x=x.view(x.size(0),-1)
  return self.fc(x)

卷积或者池化之后的tensor的维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,最后通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y),即将(channels,x,y)拉直,然后就可以和fc层连接了

补充:pytorch中view的用法(重构张量)

view在pytorch中是用来改变张量的shape的,简单又好用。

pytorch中view的用法通常是直接在张量名后用.view调用,然后放入自己想要的shape。如

tensor_name.view(shape)

Example:

1. 直接用法:

>>> x = torch.randn(4, 4)
 >>> x.size()
 torch.Size([4, 4])
 >>> y = x.view(16)
 >>> y.size()
 torch.Size([16])

2. 强调某一维度的尺寸:

>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

3. 拉直张量:

(直接填-1表示拉直, 等价于tensor_name.flatten())

>>> y = x.view(-1)
 >>> y.size()
 torch.Size([16])

4. 做维度变换时不改变内存排列

>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>> torch.equal(b, c)
False

注意最后的False,在张量b和c是不等价的。从这里我们可以看得出来,view函数如其名,只改变“看起来”的样子,不会改变张量在内存中的排列。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python程序员鲜为人知但你应该知道的17个问题
Jun 04 Python
在Python中使用base64模块处理字符编码的教程
Apr 28 Python
简单介绍Python中利用生成器实现的并发编程
May 04 Python
使用python3.5仿微软记事本notepad
Jun 15 Python
python线程、进程和协程详解
Jul 19 Python
一个基于flask的web应用诞生(1)
Apr 11 Python
opencv实现图片模糊和锐化操作
Nov 19 Python
在PyCharm中控制台输出日志分层级分颜色显示的方法
Jul 11 Python
pytorch打印网络结构的实例
Aug 19 Python
python3实现raspberry pi(树莓派)4驱小车控制程序
Feb 12 Python
Pytorch 图像变换函数集合小结
Feb 01 Python
Python OpenCV 彩色与灰度图像的转换实现
Jun 05 Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
You might like
php的memcached客户端memcached
2011/06/14 PHP
php 中英文语言转换类代码
2011/08/11 PHP
如何解决phpmyadmin导入数据库文件最大限制2048KB
2015/10/09 PHP
php 截取中英文混合字符串的方法
2018/05/31 PHP
js 操作select相关方法函数
2009/12/06 Javascript
jQuery 性能优化手册 推荐
2010/02/23 Javascript
JS禁用浏览器退格键实现思路及代码
2013/10/29 Javascript
浅谈javascript中字符串String与数组Array
2014/12/31 Javascript
微信小程序 Image API实例详解
2016/09/30 Javascript
JavaScript 继承详解(五)
2016/10/11 Javascript
jQuery简单获取DIV和A标签元素位置的方法
2017/02/07 Javascript
BootStrap的两种模态框方式
2017/05/10 Javascript
AngularJS实现动态添加Option的方法
2017/05/17 Javascript
Vue下的国际化处理方法
2017/12/18 Javascript
Nodejs中crypto模块的安全知识讲解
2018/01/03 NodeJs
node.js文件上传重命名以及移动位置的示例代码
2018/01/19 Javascript
解决vue keep-alive 数据更新的问题
2018/09/21 Javascript
Node.js path模块,获取文件后缀名操作
2020/11/07 Javascript
[45:14]Optic vs VP 2018国际邀请赛淘汰赛BO3 第二场 8.24
2018/08/25 DOTA
Python表示矩阵的方法分析
2017/05/26 Python
python中关于for循环的碎碎念
2017/06/30 Python
基于python实现蓝牙通信代码实例
2019/11/19 Python
python字典setdefault方法和get方法使用实例
2019/12/25 Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
2020/01/15 Python
Python 面向对象之类class和对象基本用法示例
2020/02/02 Python
带你学习Python如何实现回归树模型
2020/07/16 Python
matplotlib常见函数之plt.rcParams、matshow的使用(坐标轴设置)
2021/01/05 Python
css3和jquery实现的可折叠导航菜单适合放在手机网页的导航菜单
2014/09/02 HTML / CSS
纯CSS3实现漂亮的input输入框动画样式库(Text input love)
2018/12/29 HTML / CSS
CSS3 3D酷炫立方体变换动画的实现
2019/03/26 HTML / CSS
名词解释WEB SERVICE,SOAP,UDDI,WSDL,JAXP,JAXM;JSWDL开发包的介绍。
2012/10/27 面试题
什么是跨站脚本攻击
2014/12/11 面试题
大学生思想汇报范文
2013/12/31 职场文书
实验室标语
2014/06/21 职场文书
优秀团员事迹材料2000字
2014/08/20 职场文书
婚宴新郎致辞
2015/07/28 职场文书