pytorch 带batch的tensor类型图像显示操作


Posted in Python onMay 20, 2021

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

# 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的strftime()方法的使用
May 22 Python
编写Python脚本抓取网络小说来制作自己的阅读器
Aug 20 Python
Python处理Excel文件实例代码
Jun 20 Python
python将字典内容存入mysql实例代码
Jan 18 Python
pandas数值计算与排序方法
Apr 12 Python
Python高级特性切片(Slice)操作详解
Sep 27 Python
python高级特性和高阶函数及使用详解
Oct 17 Python
PyQt5实现简易电子词典
Jun 25 Python
django框架forms组件用法实例详解
Dec 10 Python
Django中modelform组件实例用法总结
Feb 10 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 Python
python实现mean-shift聚类算法
Jun 10 Python
pytorch 中nn.Dropout的使用说明
May 20 #Python
Python 线程池模块之多线程操作代码
May 20 #Python
pytorch中[..., 0]的用法说明
May 20 #Python
浅谈pytorch中stack和cat的及to_tensor的坑
May 20 #Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
You might like
星际争霸 Starcraft 编年史
2020/03/14 星际争霸
php通过排列组合实现1到9数字相加都等于20的方法
2015/08/03 PHP
jquery 页面全选框实践代码
2010/04/02 Javascript
关于jQuery新的事件绑定机制on()的使用技巧
2013/04/26 Javascript
禁止iframe脚本弹出的窗口覆盖了父窗口的方法
2014/09/06 Javascript
Javascript中replace()小结
2015/09/30 Javascript
原生javascript实现图片无缝滚动效果
2016/02/12 Javascript
浅析jquery如何判断滚动条滚到页面底部并执行事件
2016/04/29 Javascript
HTML5canvas 绘制一个圆环形的进度表示实例
2016/12/16 Javascript
Java中int与integer的区别(基本数据类型与引用数据类型)
2017/02/19 Javascript
vue+vuex+axios+echarts画一个动态更新的中国地图的方法
2017/12/19 Javascript
详解javascript 正则表达式之分组与前瞻匹配
2018/05/30 Javascript
JavaScript设计模式之缓存代理模式原理与简单用法示例
2018/08/07 Javascript
JS回调函数深入理解
2019/10/16 Javascript
JavaScript 中的六种循环方法
2021/01/06 Javascript
[00:34]TI7不朽珍藏III——纯金地穴编织者饰品展示
2017/07/15 DOTA
python打开网页和暂停实例
2014/09/30 Python
Windows下Python使用Pandas模块操作Excel文件的教程
2016/05/31 Python
Python实现的文本简单可逆加密算法示例
2017/05/18 Python
python实现验证码识别功能
2018/06/07 Python
Python实现端口检测的方法
2018/07/24 Python
python使用xlsxwriter实现有向无环图到Excel的转换
2018/12/12 Python
Python2与Python3的区别实例分析
2019/04/11 Python
win8.1安装Python 2.7版环境图文详解
2019/07/01 Python
python调用函数、类和文件操作简单实例总结
2019/11/29 Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
2020/01/15 Python
Python排序函数的使用方法详解
2020/12/11 Python
详解使用python爬取抖音app视频(appium可以操控手机)
2021/01/26 Python
美国旅游网站:Tours4Fun
2017/02/17 全球购物
聪明的粉丝购买门票的地方:TickPick
2018/03/09 全球购物
一些PHP的面试题
2015/05/06 面试题
项目合作协议书
2014/09/23 职场文书
意外伤害赔偿协议书范本
2014/09/28 职场文书
天那边观后感
2015/06/09 职场文书
拿破仑传读书笔记
2015/07/01 职场文书
python中的被动信息搜集
2021/04/29 Python