pytorch 带batch的tensor类型图像显示操作


Posted in Python onMay 20, 2021

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

# 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python+Socket实现基于UDP协议的局域网广播功能示例
Aug 31 Python
python实现随机梯度下降(SGD)
Mar 24 Python
Python3导入CSV文件的实例(跟Python2有些许的不同)
Jun 22 Python
Win10下python3.5和python2.7环境变量配置教程
Sep 18 Python
Django给admin添加Action的步骤详解
May 01 Python
在pyqt5中QLineEdit里面的内容回车发送的实例
Jun 21 Python
python实现动态数组的示例代码
Jul 15 Python
python爬取王者荣耀全皮肤的简单实现代码
Jan 31 Python
Python中remove漏删和索引越界问题的解决
Mar 18 Python
jupyter notebook 的工作空间设置操作
Apr 20 Python
解决Keras使用GPU资源耗尽的问题
Jun 22 Python
python 用opencv实现图像修复和图像金字塔
Nov 27 Python
pytorch 中nn.Dropout的使用说明
May 20 #Python
Python 线程池模块之多线程操作代码
May 20 #Python
pytorch中[..., 0]的用法说明
May 20 #Python
浅谈pytorch中stack和cat的及to_tensor的坑
May 20 #Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
You might like
ThinkPHP自动验证失败的解决方法
2011/06/09 PHP
php.ini-dist 和 php.ini-recommended 的区别介绍(方便开发与安全的朋友)
2012/07/01 PHP
一个漂亮的php验证码类(分享)
2013/08/06 PHP
php支付宝接口用法分析
2015/01/04 PHP
PHP5.3以上版本安装ZendOptimizer扩展
2015/03/27 PHP
javascript 类型判断代码分析
2010/03/28 Javascript
javascript计时器事件使用详解
2014/01/07 Javascript
Jquery中offset()和position()的区别分析
2015/02/05 Javascript
JavaScript中constructor()方法的使用简介
2015/06/05 Javascript
JS实现弹性漂浮效果的广告代码
2015/09/02 Javascript
js创建数组的简单方法
2016/07/27 Javascript
jQuery用noConflict代替$的实现方法
2017/04/12 jQuery
JavaScript使用readAsDataURL读取图像文件
2017/05/10 Javascript
JavaScript闭包的简单应用
2017/09/01 Javascript
jQuery实现input输入框获取焦点与失去焦点时提示的消失与显示功能示例
2019/05/27 jQuery
JavaScript实现网页计算器功能
2020/10/29 Javascript
Python 元类使用说明
2009/12/18 Python
在Python的Flask框架中使用模版的入门教程
2015/04/20 Python
Python中%r和%s的详解及区别
2017/03/16 Python
python字典DICT类型合并详解
2017/08/17 Python
Python使用Matplotlib实现Logos设计代码
2017/12/25 Python
python 3利用Dlib 19.7实现摄像头人脸检测特征点标定
2018/02/26 Python
Python实现的简单计算器功能详解
2018/08/25 Python
解决python中os.listdir()函数读取文件夹下文件的乱序和排序问题
2018/10/17 Python
Python获取数据库数据并保存在excel表格中的方法
2019/06/12 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
2019/08/17 Python
Django单元测试中Fixtures的使用方法
2020/02/26 Python
Python callable内置函数原理解析
2020/03/05 Python
Python 为什么推荐蛇形命名法原因浅析
2020/06/18 Python
CSS3实现多背景模拟动态边框的效果
2016/11/08 HTML / CSS
CSS3动画特效在活动页中的应用
2020/01/21 HTML / CSS
美国婴儿用品店:Babies”R”Us
2017/10/12 全球购物
印度在线购买电子产品网站:Croma
2020/01/02 全球购物
法国购买二手电子产品网站:Asgoodasnew
2020/03/27 全球购物
python状态机transitions库详解
2021/06/02 Python
使用HBuilder制作一个简单的HTML5网页
2022/07/07 HTML / CSS