浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python创建日历实例
Aug 21 Python
python基础教程之Hello World!
Aug 29 Python
python开发之基于thread线程搜索本地文件的方法
Nov 11 Python
JPype实现在python中调用JAVA的实例
Jul 19 Python
Python SQLite3简介
Feb 22 Python
python代码过长的换行方法
Jul 19 Python
pycharm使用matplotlib.pyplot不显示图形的解决方法
Oct 28 Python
Python父目录、子目录的相互调用方法
Feb 16 Python
python3转换code128条形码的方法
Apr 17 Python
python rsync服务器之间文件夹同步脚本
Aug 29 Python
python+opencv边缘提取与各函数参数解析
Mar 09 Python
教你使用一行Python代码玩遍童年的小游戏
Aug 23 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
如何做到多笔资料的同步
2006/10/09 PHP
php 广告调用类代码(支持Flash调用)
2011/08/11 PHP
PHP中去掉字符串首尾空格的方法
2012/05/19 PHP
PHP统计nginx访问日志中的搜索引擎抓取404链接页面路径
2014/06/30 PHP
phpstorm编辑器乱码问题解决
2014/12/01 PHP
php计算两个坐标(经度,纬度)之间距离的方法
2015/04/17 PHP
WordPress中用于更新伪静态规则的PHP代码实例讲解
2015/12/18 PHP
PHP数组生成XML格式数据的封装类实例
2016/11/10 PHP
PHP Mysqli 常用代码集合
2016/11/12 PHP
JS类库Bindows1.3中的内存释放方式分析
2007/03/08 Javascript
from 表单提交返回值用post或者是get方法实现
2013/08/21 Javascript
JQuery性能优化的几点建议
2014/05/14 Javascript
Node.js的包详细介绍
2015/01/14 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
jQuery实现定时读取分析xml文件的方法
2015/07/16 Javascript
window.setInterval()方法的定义和用法及offsetLeft与style.left的区别
2015/11/11 Javascript
JavaScript中this的用法实例分析
2016/12/19 Javascript
微信小程序实现复选框效果
2018/12/28 Javascript
vue 移动端注入骨架屏的配置方法
2019/06/25 Javascript
Vue-cli项目部署到Nginx服务器的方法
2019/11/01 Javascript
解决vue单页面应用进入页面加载所有 js 的问题
2020/08/12 Javascript
React倒计时功能实现代码——解耦通用
2020/09/18 Javascript
django 利用Q对象与F对象进行查询的实现
2020/05/15 Python
python 对一幅灰度图像进行直方图均衡化
2020/10/27 Python
Python 求向量的余弦值操作
2021/03/04 Python
pytorch 计算Parameter和FLOP的操作
2021/03/04 Python
amazeui时间组件的实现示例
2020/08/18 HTML / CSS
MCAKE蛋糕官方网站:一直都是巴黎的味道
2018/02/06 全球购物
《望洞庭》教学反思
2014/02/16 职场文书
小组合作学习反思
2014/02/18 职场文书
大专生找工作自荐书
2014/06/10 职场文书
违反单位工作制度检讨书
2014/10/25 职场文书
自愿离婚协议书范本
2015/01/26 职场文书
勤俭节约倡议书范文
2015/04/29 职场文书
tensorflow+k-means聚类简单实现猫狗图像分类的方法
2021/04/28 Python
vue 自定义的组件绑定点击事件
2022/04/21 Vue.js