pytorch 带batch的tensor类型图像显示操作


Posted in Python onMay 20, 2021

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

# 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
采用python实现简单QQ单用户机器人的方法
Jul 03 Python
python实现将汉字转换成汉语拼音的库
May 05 Python
python简单商城购物车实例代码
Mar 15 Python
python实现寻找最长回文子序列的方法
Jun 02 Python
Python对象与引用的介绍
Jan 24 Python
PyQt5实现五子棋游戏(人机对弈)
Mar 24 Python
详解python使用turtle库来画一朵花
Mar 21 Python
深入了解和应用Python 装饰器 @decorator
Apr 02 Python
python程序运行进程、使用时间、剩余时间显示功能的实现代码
Jul 11 Python
Anaconda和ipython环境适配的实现
Apr 22 Python
python读取xml文件方法解析
Aug 04 Python
Elasticsearch 批量操作
Apr 19 Python
pytorch 中nn.Dropout的使用说明
May 20 #Python
Python 线程池模块之多线程操作代码
May 20 #Python
pytorch中[..., 0]的用法说明
May 20 #Python
浅谈pytorch中stack和cat的及to_tensor的坑
May 20 #Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
You might like
windows xp下安装pear
2006/12/02 PHP
MySQL相关说明
2007/01/15 PHP
使用PHP socke 向指定页面提交数据
2008/07/23 PHP
laravel5表单唯一验证的实例代码
2019/09/30 PHP
用javascript获取textarea中的光标位置
2008/05/06 Javascript
Jquery下的26个实用小技巧(jQuery tips, tricks & solutions)
2010/03/01 Javascript
jQuery focus和blur事件的应用详解
2014/01/26 Javascript
javascript实现十六进制颜色值(HEX)和RGB格式相互转换
2014/06/20 Javascript
javascript刷新父页面的各种方法汇总
2014/09/03 Javascript
Jquery api 速查表分享
2015/01/12 Javascript
JavaScript实现为指定对象添加多个事件处理程序的方法
2015/04/17 Javascript
js禁止页面刷新与后退的方法
2015/06/08 Javascript
jQuery实现的简单折叠菜单(折叠面板)效果代码
2015/09/16 Javascript
javascript实现别踩白块儿小游戏程序
2015/11/22 Javascript
ajax 提交数据到后台jsp页面及页面跳转问题
2017/01/19 Javascript
JavaScript 中Date对象的格式化代码方法汇总
2017/09/06 Javascript
十个免费的web前端开发工具详细整理
2017/09/18 Javascript
详解Angular调试技巧之报错404(not found)
2018/01/31 Javascript
vue实现商品加减计算总价的实例代码
2018/08/12 Javascript
JavaScript格式化json和xml的方法示例
2019/01/22 Javascript
JS中如何轻松遍历对象属性的方式总结
2019/08/06 Javascript
详解webpack打包vue项目之后生成的dist文件该怎么启动运行
2019/09/06 Javascript
python cx_Oracle模块的安装和使用详细介绍
2017/02/13 Python
python存储16bit和32bit图像的实例
2018/12/05 Python
Python 多线程搜索txt文件的内容,并写入搜到的内容(Lock)方法
2019/08/23 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
2020/06/05 Python
PyCharm2020.1.2社区版安装,配置及使用教程详解(Windows)
2020/08/07 Python
CSS3 实现倒计时效果
2020/11/25 HTML / CSS
阿里旅行:飞猪
2017/01/05 全球购物
举例说明类变量和实例变量的区别
2016/06/30 面试题
公司道歉信范文
2014/01/09 职场文书
应届生自荐书
2014/06/23 职场文书
旅游活动总结
2014/08/27 职场文书
解除劳动关系协议书2篇
2014/11/28 职场文书
企业投资意向书
2015/05/09 职场文书
SQL优化老出错,那是你没弄明白MySQL解释计划用法
2021/11/27 MySQL