浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用判断语句和循环的教程
Apr 25 Python
python 性能提升的几种方法
Jul 15 Python
Python实现读取并保存文件的类
May 11 Python
Python实现基于PIL和tesseract的验证码识别功能示例
Jul 11 Python
python读取目录下最新的文件夹方法
Dec 24 Python
Python 类的私有属性和私有方法实例分析
Sep 29 Python
使用 Python 读取电子表格中的数据实例详解
Apr 17 Python
基于python实现对文件进行切分行
Apr 26 Python
Pandas中DataFrame基本函数整理(小结)
Jul 20 Python
PyTorch 导数应用的使用教程
Aug 31 Python
Python3 pyecharts生成Html文件柱状图及折线图代码实例
Sep 29 Python
解决jupyter notebook图片显示模糊和保存清晰图片的操作
Apr 24 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
将OICQ数据转成MYSQL数据
2006/10/09 PHP
php连接mysql数据库代码
2009/03/10 PHP
用php实现选择排序的解决方法
2013/05/04 PHP
php在linux下检测mysql同步状态的方法
2015/01/15 PHP
php 防止表单重复提交两种实现方法
2016/11/03 PHP
php读取XML的常见方法实例总结
2017/04/25 PHP
关于PHP中interface的用处详解
2020/07/26 PHP
CSDN轮换广告图片轮换效果
2007/03/27 Javascript
js解析与序列化json数据(三)json的解析探讨
2013/02/01 Javascript
js实现Form栏显示全格式时间时钟效果代码
2015/08/19 Javascript
继续学习javascript闭包
2015/12/03 Javascript
浅谈js之字面量、对象字面量的访问、关键字in的用法
2016/11/20 Javascript
javascript基础知识之html5轮播图实例讲解(44)
2017/02/17 Javascript
JavaScript中Promise的使用详解
2017/02/26 Javascript
Vue.js实现多条件筛选、搜索、排序及分页的表格功能
2020/11/24 Javascript
angular仿支付宝密码框输入效果
2017/03/25 Javascript
ajax +NodeJS 实现图片上传实例
2017/06/06 NodeJs
Vue.js实现网格列表布局转换方法
2017/08/25 Javascript
vue.js实现点击后动态添加class及删除同级class的实现代码
2018/04/04 Javascript
JavaScript引用类型Date常见用法实例分析
2018/08/08 Javascript
python实现发送邮件功能
2017/07/22 Python
Python微信库:itchat的用法详解
2017/08/14 Python
python调用百度REST API实现语音识别
2018/08/30 Python
一篇文章弄懂Python中所有数组数据类型
2019/06/23 Python
Pandas DataFrame数据的更改、插入新增的列和行的方法
2019/06/25 Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
2020/11/10 Python
CSS3 3D制作实战案例分析
2016/09/18 HTML / CSS
html5实现微信打飞机游戏
2014/03/27 HTML / CSS
KIKO MILANO英国官网:意大利知名化妆品和护肤品品牌
2017/09/25 全球购物
Probikekit欧盟:在线公路自行车专家
2019/07/12 全球购物
Linux不知道文件后缀名怎么判断文件类型
2012/04/26 面试题
技术股东合作协议书
2014/12/02 职场文书
单位实习鉴定评语
2015/01/04 职场文书
初中信息技术教学计划
2015/01/22 职场文书
python 自动刷新网页的两种方法
2021/04/20 Python
Vue实现跑马灯样式文字横向滚动
2021/11/23 Vue.js