浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
windows及linux环境下永久修改pip镜像源的方法
Nov 28 Python
python计算两个地址之间的距离方法
Jun 09 Python
Python SQL查询并生成json文件操作示例
Aug 17 Python
Python Matplotlib实现三维数据的散点图绘制
Mar 19 Python
pyqt5 实现多窗口跳转的方法
Jun 19 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
Jul 06 Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 Python
Python Pillow.Image 图像保存和参数选择方式
Jan 09 Python
PyCharm 2020 激活到 2100 年的教程
Mar 25 Python
Python 的 __str__ 和 __repr__ 方法对比
Sep 02 Python
Python 列表推导式需要注意的地方
Oct 23 Python
Python开发五子棋小游戏
May 02 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
咖啡豆分级制度 咖啡豆等级分类 咖啡豆是按口感分类的吗?
2021/03/05 新手入门
php检测iis环境是否支持htaccess的方法
2014/02/18 PHP
thinkPHP框架自动填充原理与用法分析
2018/04/03 PHP
详解php中curl返回false的解决办法
2019/03/18 PHP
IE中radio 或checkbox的checked属性初始状态下不能选中显示问题
2009/07/25 Javascript
jquery.validate使用攻略 第二部
2010/07/01 Javascript
用js实现table单元格高宽调整,兼容合并单元格(兼容IE6、7、8、FF)实例
2013/06/25 Javascript
javascript自定义startWith()和endWith()的两种方法
2013/11/11 Javascript
js 上下左右键控制焦点(示例代码)
2013/12/14 Javascript
AngularJs Injecting Services Into Controllers详解
2016/09/02 Javascript
BootStrap框架个人总结(bootstrap框架、导航条、下拉菜单、轮播广告carousel、栅格系统布局、标签页tabs、模态框、菜单定位)
2016/12/01 Javascript
js学习总结之DOM2兼容处理重复问题的解决方法
2017/07/27 Javascript
解决Vue 浏览器后退无法触发beforeRouteLeave的问题
2017/12/24 Javascript
通过一次报错详细谈谈Point事件
2018/05/17 Javascript
vue 监听某个div垂直滚动条下拉到底部的方法
2018/09/15 Javascript
vue中多个倒计时实现代码实例
2019/03/27 Javascript
微信小程序 组件的外部样式externalClasses使用详解
2019/09/06 Javascript
浅谈vue3中effect与computed的亲密关系
2019/10/10 Javascript
微信小程序点击顶部导航栏切换样式代码实例
2019/11/12 Javascript
微信小程序转化为uni-app项目的方法示例
2020/05/22 Javascript
[07:09]DOTA2-DPC中国联赛 正赛 Ehome vs Elephant 选手采访
2021/03/11 DOTA
Python获取服务器信息的最简单实现方法
2015/03/05 Python
python学习之面向对象【入门初级篇】
2017/01/21 Python
pandas数据清洗,排序,索引设置,数据选取方法
2018/05/18 Python
python3+requests接口自动化session操作方法
2018/10/13 Python
Python3爬虫学习之爬虫利器Beautiful Soup用法分析
2018/12/12 Python
python求最大值,不使用内置函数的实现方法
2019/07/09 Python
Python使用tkinter模块实现推箱子游戏
2019/10/08 Python
python3代码中实现加法重载的实例
2020/12/03 Python
python 基于selenium实现鼠标拖拽功能
2020/12/24 Python
详解基于 Canvas 手撸一个六边形能力图
2019/09/02 HTML / CSS
美国专注于健康商品的网站:eVitamins
2017/01/23 全球购物
美国名牌太阳镜折扣网站:Eyedictive
2017/05/15 全球购物
十岁生日家长答谢词
2014/01/17 职场文书
简历中自我评价范文
2015/03/11 职场文书
七个Python必备的GUI库
2021/04/27 Python