对pytorch中x = x.view(x.size(0), -1) 的理解说明


Posted in Python onMarch 03, 2021

在pytorch的CNN代码中经常会看到

x.view(x.size(0), -1)

首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度

a = torch.Tensor(2,3)
print(a)
# tensor([[0.0000, 0.0000, 0.0000],
#    [0.0000, 0.0000, 0.0000]])
 
print(a.view(1,-1))
# tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

在CNN中卷积或者池化之后需要连接全连接层,所以需要把多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能

def forward(self,x):
  x=self.pre(x)
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  x=self.layer4(x)
    
  x=F.avg_pool2d(x,7)
  x=x.view(x.size(0),-1)
  return self.fc(x)

卷积或者池化之后的tensor的维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,最后通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y),即将(channels,x,y)拉直,然后就可以和fc层连接了

补充:pytorch中view的用法(重构张量)

view在pytorch中是用来改变张量的shape的,简单又好用。

pytorch中view的用法通常是直接在张量名后用.view调用,然后放入自己想要的shape。如

tensor_name.view(shape)

Example:

1. 直接用法:

>>> x = torch.randn(4, 4)
 >>> x.size()
 torch.Size([4, 4])
 >>> y = x.view(16)
 >>> y.size()
 torch.Size([16])

2. 强调某一维度的尺寸:

>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

3. 拉直张量:

(直接填-1表示拉直, 等价于tensor_name.flatten())

>>> y = x.view(-1)
 >>> y.size()
 torch.Size([16])

4. 做维度变换时不改变内存排列

>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>> torch.equal(b, c)
False

注意最后的False,在张量b和c是不等价的。从这里我们可以看得出来,view函数如其名,只改变“看起来”的样子,不会改变张量在内存中的排列。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python中的Matplotlib模块入门教程
Apr 15 Python
python删除特定文件的方法
Jul 30 Python
Python语法快速入门指南
Oct 12 Python
go和python变量赋值遇到的一个问题
Aug 31 Python
使用python爬取B站千万级数据
Jun 08 Python
Python操作SQLite数据库过程解析
Sep 02 Python
为什么说Python可以实现所有的算法
Oct 04 Python
Django 简单实现分页与搜索功能的示例代码
Nov 07 Python
解决tensorflow打印tensor有省略号的问题
Feb 04 Python
django haystack实现全文检索的示例代码
Jun 24 Python
拿来就用!Python批量合并PDF的示例代码
Aug 10 Python
Python源码解析之List
May 21 Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
You might like
一贴学会PHP 新手入门教程
2009/08/03 PHP
ThinkPHP中的三大自动简介
2014/08/22 PHP
PHP中使用hidef扩展代替define提高性能
2015/04/09 PHP
关于WordPress的SEO优化相关的一些PHP页面脚本技巧
2015/12/10 PHP
PHP实现的redis主从数据库状态检测功能示例
2017/07/20 PHP
Laravel框架路由和控制器的绑定操作方法
2018/06/12 PHP
javascript之dhDataGrid Ver2.0.0代码
2007/07/01 Javascript
JQuery 网站换肤功能实现代码
2009/11/02 Javascript
浅析用prototype定义自己的方法
2013/11/14 Javascript
JS拖拽组件学习使用
2016/01/19 Javascript
详解javascript实现瀑布流列式布局
2016/01/29 Javascript
JS实现兼容各种浏览器的高级拖动方法完整实例【测试可用】
2016/06/21 Javascript
详解vue数据渲染出现闪烁问题
2017/06/29 Javascript
在一个页面实现两个zTree联动的方法
2017/12/20 Javascript
Vue+mui实现图片的本地缓存示例代码
2018/05/24 Javascript
浅谈Vue路由快照实现思路及其问题
2018/06/07 Javascript
开发Node CLI构建微信小程序脚手架的示例
2020/03/27 Javascript
python写xml文件的操作实例
2014/10/05 Python
Python画图学习入门教程
2016/07/01 Python
Python3实现的字典遍历操作详解
2018/04/18 Python
Python实现端口检测的方法
2018/07/24 Python
Django forms表单 select下拉框的传值实例
2019/07/19 Python
Django密码系统实现过程详解
2019/07/19 Python
利用python-pypcap抓取带VLAN标签的数据包方法
2019/07/23 Python
对Django url的几种使用方式详解
2019/08/06 Python
python判断自身是否正在运行的方法
2019/08/08 Python
Lands’ End官网:经典的美国生活方式品牌
2016/08/14 全球购物
节省高达65%的城市景点费用:Go City
2019/07/06 全球购物
精彩的大学生自我评价
2013/11/17 职场文书
会计系毕业求职信
2014/08/07 职场文书
2014年人事专员工作总结
2014/11/19 职场文书
学生检讨书
2015/01/27 职场文书
食堂采购员岗位职责
2015/04/03 职场文书
学生病假条范文
2015/08/17 职场文书
2016年小学感恩节活动总结
2016/04/01 职场文书
ubuntu端向日葵键盘输入卡顿问题及解决
2022/12/24 Servers