浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中self原理实例分析
Apr 30 Python
在Python中操作时间之mktime()方法的使用教程
May 22 Python
Ruby使用eventmachine为HTTP服务器添加文件下载功能
Apr 20 Python
python、java等哪一门编程语言适合人工智能?
Nov 13 Python
Python3.x爬虫下载网页图片的实例讲解
May 22 Python
Python3实现取图片中特定的像素替换指定的颜色示例
Jan 24 Python
PyQT5 emit 和 connect的用法详解
Dec 13 Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 Python
Python多线程实现支付模拟请求过程解析
Apr 21 Python
python报错TypeError: ‘NoneType‘ object is not subscriptable的解决方法
Nov 05 Python
Python语言规范之Pylint的详细用法
Jun 24 Python
用Python可视化新冠疫情数据
Jan 18 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
咖啡知识 除了喝咖啡还有那些知识点
2021/03/06 新手入门
PHP两种去掉数组重复值的方法比较
2014/06/19 PHP
php实现的用户查询类实例
2015/06/18 PHP
laravel orm 关联条件查询代码
2019/10/21 PHP
关于js注册事件的常用方法
2013/04/03 Javascript
JS格式化数字金额用逗号隔开保留两位小数
2013/10/18 Javascript
js返回前一页刷新本页重载页面
2014/07/29 Javascript
修复jQuery tablesorter无法正确排序的bug(加千分位数字后)
2016/03/30 Javascript
探寻JavaScript中this指针指向
2016/04/23 Javascript
Vue.js教程之计算属性
2016/11/11 Javascript
ES6中Generator与异步操作实例分析
2017/03/31 Javascript
react-router JS 控制路由跳转实例
2017/06/15 Javascript
AngularJS路由删除#符号解决的办法
2017/09/28 Javascript
jQuery实现html双向绑定功能示例
2017/10/09 jQuery
利用Angular2 + Ionic3开发IOS应用实例教程
2018/01/15 Javascript
详解如何使用nvm管理Node.js多版本
2019/05/06 Javascript
vue解决使用$http获取数据时报错的问题
2019/10/30 Javascript
OpenLayers加载缩放控件使用方法详解
2020/09/25 Javascript
Python备份Mysql脚本
2008/08/11 Python
使用setup.py安装python包和卸载python包的方法
2013/11/27 Python
使用python批量修改文件名的方法(视频合并时)
2020/03/24 Python
Python图像处理库PIL中图像格式转换的实现
2020/02/26 Python
python实现邮件循环自动发件功能
2020/09/11 Python
Python通过fnmatch模块实现文件名匹配
2020/09/30 Python
联想加拿大官方网站:Lenovo Canada
2018/04/05 全球购物
澳大利亚先进的皮肤和激光诊所购物网站:Soho Skincare
2018/10/15 全球购物
美国渔具店:FishUSA
2019/08/07 全球购物
杭州时比特电子有限公司SQL
2013/08/22 面试题
介绍一下write命令
2012/09/24 面试题
七年级英语教学反思
2014/01/15 职场文书
教师节倡议书
2014/08/30 职场文书
工作所在部门证明
2014/09/21 职场文书
社区班子个人对照检查材料思想汇报
2014/10/07 职场文书
父亲节寄语大全
2015/02/27 职场文书
Python采集股票数据并制作可视化柱状图
2022/04/04 Python
Python sklearn分类决策树方法详解
2022/09/23 Python