浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中urllib2模块的8个使用细节分享
Jan 01 Python
用Python解决计数原理问题的方法
Aug 04 Python
Python程序中设置HTTP代理
Nov 06 Python
Python实现Mysql数据库连接池实例详解
Apr 11 Python
pyhton列表转换为数组的实例
Apr 04 Python
Python给定一个句子倒序输出单词以及字母的方法
Dec 20 Python
Python基于滑动平均思想实现缺失数据填充的方法
Feb 21 Python
详解Python 定时框架 Apscheduler原理及安装过程
Jun 14 Python
python web框架中实现原生分页
Sep 08 Python
Python数据存储之 h5py详解
Dec 26 Python
关于python中导入文件到list的问题
Oct 31 Python
python数据可视化使用pyfinance分析证券收益示例详解
Nov 20 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
支持php4、php5的mysql数据库操作类
2008/01/10 PHP
使用Limit参数优化MySQL查询的方法
2008/11/12 PHP
PHP PDO函数库详解
2010/04/27 PHP
Yii实现MySQL多数据库和读写分离实例分析
2014/12/03 PHP
使用PHP访问RabbitMQ消息队列的方法示例
2018/06/06 PHP
Javascript 对象的解释
2008/11/24 Javascript
js利用数组length属性清空和截短数组的小例子
2014/01/15 Javascript
jQuery学习笔记之基础中的基础
2015/01/19 Javascript
leaflet的开发入门教程
2016/11/17 Javascript
node使用UEditor富文本编辑器的方法实例
2017/07/11 Javascript
分享vue.js devtools遇到一系列问题
2017/10/24 Javascript
jquery如何实现点击空白处隐藏元素
2017/12/05 jQuery
bing Map 在vue项目中的使用详解
2018/04/09 Javascript
Babel 入门教程学习笔记
2018/06/13 Javascript
JS中的const命令你真懂它吗
2020/03/08 Javascript
微信小程序实现电子签名功能
2020/07/29 Javascript
[48:21]Mski vs VGJ.S Supermajor小组赛C组 BO3 第一场 6.3
2018/06/04 DOTA
Python selenium 父子、兄弟、相邻节点定位方式详解
2016/09/15 Python
Python的argparse库使用详解
2018/10/09 Python
python使用xlsxwriter实现有向无环图到Excel的转换
2018/12/12 Python
Appium Python自动化测试之环境搭建的步骤
2019/01/23 Python
使用python绘制温度变化雷达图
2019/10/18 Python
基于Python采集爬取微信公众号历史数据
2020/11/27 Python
美国羊皮公司:Overland
2018/01/15 全球购物
英国手机壳购买网站:Case Hut
2019/04/11 全球购物
夏威夷咖啡公司:Hawaii Coffee Company
2019/09/19 全球购物
俄罗斯童装网上商店:BebaKids
2020/06/06 全球购物
提高EJB性能都有哪些技巧
2012/03/25 面试题
监理员的岗位职责
2013/11/13 职场文书
庆七一活动方案
2014/01/25 职场文书
社区党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
亮剑精神观后感
2015/06/05 职场文书
2016年全国爱牙日宣传活动总结
2016/04/05 职场文书
2019班干部竞选演讲稿范本!
2019/07/08 职场文书
Python使用海龟绘图实现贪吃蛇游戏
2021/06/18 Python
Mysql表数据比较大情况下修改添加字段的方法实例
2022/06/28 MySQL