对pytorch中x = x.view(x.size(0), -1) 的理解说明


Posted in Python onMarch 03, 2021

在pytorch的CNN代码中经常会看到

x.view(x.size(0), -1)

首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度

a = torch.Tensor(2,3)
print(a)
# tensor([[0.0000, 0.0000, 0.0000],
#    [0.0000, 0.0000, 0.0000]])
 
print(a.view(1,-1))
# tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

在CNN中卷积或者池化之后需要连接全连接层,所以需要把多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能

def forward(self,x):
  x=self.pre(x)
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  x=self.layer4(x)
    
  x=F.avg_pool2d(x,7)
  x=x.view(x.size(0),-1)
  return self.fc(x)

卷积或者池化之后的tensor的维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,最后通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y),即将(channels,x,y)拉直,然后就可以和fc层连接了

补充:pytorch中view的用法(重构张量)

view在pytorch中是用来改变张量的shape的,简单又好用。

pytorch中view的用法通常是直接在张量名后用.view调用,然后放入自己想要的shape。如

tensor_name.view(shape)

Example:

1. 直接用法:

>>> x = torch.randn(4, 4)
 >>> x.size()
 torch.Size([4, 4])
 >>> y = x.view(16)
 >>> y.size()
 torch.Size([16])

2. 强调某一维度的尺寸:

>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

3. 拉直张量:

(直接填-1表示拉直, 等价于tensor_name.flatten())

>>> y = x.view(-1)
 >>> y.size()
 torch.Size([16])

4. 做维度变换时不改变内存排列

>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>> torch.equal(b, c)
False

注意最后的False,在张量b和c是不等价的。从这里我们可以看得出来,view函数如其名,只改变“看起来”的样子,不会改变张量在内存中的排列。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python学习笔记之解析json的方法分析
Apr 21 Python
使用pandas模块读取csv文件和excel表格,并用matplotlib画图的方法
Jun 22 Python
Python离线安装PIL 模块的方法
Jan 08 Python
python取余运算符知识点详解
Jun 27 Python
python如何将两个txt文件内容合并
Oct 18 Python
如何在python中实现随机选择
Nov 02 Python
开启Django博客的RSS功能的实现方法
Feb 17 Python
Python爬虫程序架构和运行流程原理解析
Mar 09 Python
keras.layer.input()用法说明
Jun 16 Python
flask项目集成swagger的方法
Dec 09 Python
利用Python判断你的密码难度等级
Jun 02 Python
OpenCV-Python实现轮廓的特征值
Jun 09 Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
You might like
MVC模式的PHP实现
2006/10/09 PHP
php array的学习笔记
2012/05/16 PHP
解析php file_exists无效的解决办法
2013/06/26 PHP
PHP中iconv函数知识汇总
2015/07/02 PHP
基于php实现随机合并数组并排序(原排序)
2015/11/26 PHP
PHP 年月日的三级联动实例代码
2017/05/24 PHP
微信公众平台开发教程④ ThinkPHP框架下微信支付功能图文详解
2019/04/10 PHP
php源码的使用方法讲解
2019/09/26 PHP
对Jquery中的ajax再封装,简化操作示例
2014/02/12 Javascript
js 判断浏览器使用的语言示例代码
2014/03/22 Javascript
javascript常用的正则表达式实例
2014/05/15 Javascript
javascript使用window.open提示“已经计划系统关机”的原因
2014/08/15 Javascript
JS遍历数组及打印数组实例分析
2016/01/21 Javascript
使用AngularJS 跨站请求如何解决jsonp请求问题
2017/01/16 Javascript
什么是Vue.js框架 为什么选择它?
2017/10/17 Javascript
浅析从vue源码看观察者模式
2018/01/29 Javascript
jQuery解析json格式数据示例
2018/09/01 jQuery
简单分析js中的this的原理
2019/08/31 Javascript
微信小程序进入广告实现代码实例
2019/09/19 Javascript
小程序按钮避免多次调用接口和点击方案实现(不用showLoading)
2020/04/15 Javascript
Python中Django框架利用url来控制登录的方法
2015/07/25 Python
Python实用技巧之利用元组代替字典并为元组元素命名
2018/07/11 Python
python pyheatmap包绘制热力图
2018/11/09 Python
python hashlib加密实现代码
2019/10/17 Python
详解python opencv、scikit-image和PIL图像处理库比较
2019/12/26 Python
Django 自定义404 500等错误页面的实现
2020/03/08 Python
python 实现压缩和解压缩的示例
2020/09/22 Python
基于python实现监听Rabbitmq系统日志代码示例
2020/11/28 Python
详解使用python爬取抖音app视频(appium可以操控手机)
2021/01/26 Python
Yahoo-PHP面试题4
2012/05/05 面试题
经典c++面试题二
2015/08/14 面试题
倡议书格式范文
2014/04/14 职场文书
公司合作协议书范本
2014/04/18 职场文书
2014年人事科工作总结
2014/11/19 职场文书
如何用JavaScript实现一个数组惰性求值库
2021/05/05 Javascript
基于Redis过期事件实现订单超时取消
2021/05/08 Redis