浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3 socket同步通信简单示例
Jun 07 Python
python 反向输出字符串的方法
Jul 16 Python
基于tensorflow加载部分层的方法
Jul 26 Python
Python 实现中值滤波、均值滤波的方法
Jan 09 Python
Python解决pip install时出现的Could not fetch URL问题
Aug 01 Python
详解Python self 参数
Aug 30 Python
python编写俄罗斯方块
Mar 13 Python
python IDLE添加行号显示教程
Apr 25 Python
python 对象真假值的实例(哪些视为False)
Dec 11 Python
Appium+Python实现简单的自动化登录测试的实现
Jan 26 Python
Python与C/C++的相互调用案例
Mar 04 Python
Python自动操作神器PyAutoGUI的使用教程
Jun 16 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
第六节--访问属性和方法
2006/11/16 PHP
php中instanceof 与 is_a()区别分析
2015/03/03 PHP
php简单统计字符串单词数量的方法
2015/06/19 PHP
PHP 设计模式系列之 specification规格模式
2016/01/10 PHP
php删除数组指定元素实现代码
2017/05/03 PHP
PHP如何使用JWT做Api接口身份认证的实现
2020/02/03 PHP
jQuery 剧场版 你必须知道的javascript
2009/05/27 Javascript
js保存当前路径(cookies记录)
2010/12/14 Javascript
Js从头学起(基本数据类型和引用类型的参数传递详细分析)
2012/02/16 Javascript
使用focus方法让光标默认停留在INPUT框
2014/07/29 Javascript
学习使用AngularJS文件上传控件
2016/02/16 Javascript
基于jquery实现ajax无刷新评论
2020/08/19 Javascript
详解为Angular.js内置$http服务添加拦截器的方法
2016/12/20 Javascript
微信小程序实战之上拉(分页加载)效果(2)
2017/04/17 Javascript
bootstrap选项卡扩展功能详解
2017/06/14 Javascript
深入理解vue.js中的v-if和v-show
2017/06/22 Javascript
基于vue实现网站前台的权限管理(前后端分离实践)
2018/01/13 Javascript
JS基于封装函数实现的表格分页完整示例
2018/06/26 Javascript
微信小程序实现带参数的分享功能(两种方法)
2019/05/17 Javascript
解决Layui中layer报错的问题
2019/09/03 Javascript
angular8和ngrx8结合使用的步骤介绍
2019/12/01 Javascript
Vue.extend 登录注册模态框的实现
2020/12/29 Vue.js
详解Python的Django框架中inclusion_tag的使用
2015/07/21 Python
不知道这5种下划线的含义,你就不算真的会Python!
2018/10/09 Python
python设置环境变量的作用整理
2020/02/17 Python
python实现3D地图可视化
2020/03/25 Python
Python小白不正确的使用类变量实例
2020/05/29 Python
python os.rename实例用法详解
2020/12/06 Python
美国最大的高尔夫发球时间预订网站:TeeOff.com
2018/03/28 全球购物
C语言开发工程师测试题
2016/12/20 面试题
建筑工地质量标语
2014/06/12 职场文书
代办社保委托书范文
2014/10/06 职场文书
优秀教师事迹材料
2014/12/15 职场文书
2015年初三班主任工作总结
2015/05/21 职场文书
创业计划书之电动车企业
2019/10/11 职场文书
python process模块的使用简介
2021/05/14 Python