pytorch 带batch的tensor类型图像显示操作


Posted in Python onMay 20, 2021

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

# 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用OpenCV进行人脸检测的例子
Apr 18 Python
python实现rest请求api示例
Apr 22 Python
Python应用03 使用PyQT制作视频播放器实例
Dec 07 Python
Python字符串处理实例详解
May 18 Python
python3.x 将byte转成字符串的方法
Jul 17 Python
python 对多个csv文件分别进行处理的方法
Jan 07 Python
Python 中list ,set,dict的大规模查找效率对比详解
Oct 11 Python
Python list与NumPy array 区分详解
Nov 06 Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 Python
常用的10个Python实用小技巧
Aug 10 Python
详解pycharm自动import所需的库的操作方法
Nov 30 Python
Python更改pip镜像源的方法示例
Dec 01 Python
pytorch 中nn.Dropout的使用说明
May 20 #Python
Python 线程池模块之多线程操作代码
May 20 #Python
pytorch中[..., 0]的用法说明
May 20 #Python
浅谈pytorch中stack和cat的及to_tensor的坑
May 20 #Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
You might like
最令PHP初学者头痛的十四个问题
2006/07/12 PHP
PHP语法自动检查的Vim插件
2014/08/11 PHP
PHP响应post请求上传文件的方法
2015/12/17 PHP
PHP实现随机生成水印图片功能
2017/03/22 PHP
PHP addslashes()函数讲解
2019/02/03 PHP
javascript 写类方式之九
2009/07/05 Javascript
Extjs Ext.MessageBox.confirm 确认对话框详解
2010/04/02 Javascript
跟我学习javascript的严格模式
2015/11/16 Javascript
基于JS实现密码框(password)中显示文字提示功能代码
2016/05/27 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
input 禁止输入特殊字符的四种实现方式
2016/08/24 Javascript
js获取当前页的URL与window.location.href简单方法
2017/02/13 Javascript
Vue Ajax跨域请求实例详解
2017/06/20 Javascript
Vue的实例、生命周期与Vue脚手架(vue-cli)实例详解
2017/12/27 Javascript
微信小程序block的使用教程
2018/04/01 Javascript
vue使用v-if v-show页面闪烁,div闪现的解决方法
2018/10/12 Javascript
Python使用arrow库优雅地处理时间数据详解
2017/10/10 Python
Python实现的井字棋(Tic Tac Toe)游戏示例
2018/01/31 Python
python实现二维数组的对角线遍历
2019/03/02 Python
Python中的引用知识点总结
2019/05/20 Python
python实现键盘输入的实操方法
2019/07/16 Python
详解pandas使用drop_duplicates去除DataFrame重复项参数
2019/08/01 Python
python 字典 setdefault()和get()方法比较详解
2019/08/07 Python
决策树剪枝算法的python实现方法详解
2019/09/18 Python
如何提高python 中for循环的效率
2020/04/15 Python
基于python实现检索标记敏感词并输出
2020/05/07 Python
jupyter notebook更换皮肤主题的实现
2021/01/07 Python
selenium如何定位span元素的实现
2021/01/13 Python
localstorage和sessionstorage使用记录(推荐)
2017/05/23 HTML / CSS
马来西亚时装购物网站:ZALORA马来西亚
2017/03/14 全球购物
Sofft鞋官网:世界知名鞋类品牌
2017/03/28 全球购物
学校消防安全制度
2014/01/30 职场文书
篮球拉拉队口号
2015/12/25 职场文书
学习弘扬焦裕禄精神心得体会
2016/01/23 职场文书
2016年优秀教师先进事迹材料
2016/02/26 职场文书
Go中使用gjson来操作JSON数据的实现
2022/08/14 Golang