pytorch 带batch的tensor类型图像显示操作


Posted in Python onMay 20, 2021

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

# 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux系统使用python获取cpu信息脚本分享
Jan 15 Python
VScode编写第一个Python程序HelloWorld步骤
Apr 06 Python
pandas.dataframe按行索引表达式选取方法
Oct 30 Python
Python语言进阶知识点总结
May 28 Python
python 使用socket传输图片视频等文件的实现方式
Aug 07 Python
python实现在多维数组中挑选符合条件的全部元素
Nov 26 Python
Python3监控windows,linux系统的CPU、硬盘、内存使用率和各个端口的开启情况详细代码实例
Mar 18 Python
Python实现汇率转换操作
May 03 Python
python设置表格边框的具体方法
Jul 17 Python
Python实现AES加密,解密的两种方法
Oct 03 Python
Python使用random模块实现掷骰子游戏的示例代码
Apr 29 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
May 28 Python
pytorch 中nn.Dropout的使用说明
May 20 #Python
Python 线程池模块之多线程操作代码
May 20 #Python
pytorch中[..., 0]的用法说明
May 20 #Python
浅谈pytorch中stack和cat的及to_tensor的坑
May 20 #Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
You might like
php上传文件的增强函数
2010/07/21 PHP
php格式化金额函数分享
2015/02/02 PHP
PHP实现的解汉诺塔问题算法示例
2018/08/06 PHP
PHP PDOStatement::getColumnMeta讲解
2019/02/01 PHP
jQuery 选择器理解
2010/03/16 Javascript
基于jquery的下拉框改变动态添加和删除表格实现代码
2020/09/12 Javascript
js 回车提交表单两种实现方法
2012/12/31 Javascript
JS匀速运动演示示例代码
2013/11/26 Javascript
关于页面嵌入swf覆盖div层的问题的解决方法
2014/02/11 Javascript
JavaScript DOM操作表格及样式
2015/04/13 Javascript
javascript之Array 数组对象详解
2016/06/07 Javascript
AngularJS操作键值对象类似java的hashmap(填坑小结)
2016/11/12 Javascript
详解微信小程序——自定义圆形进度条
2016/12/29 Javascript
easyui关于validatebox实现多重规则验证的方法(必看)
2017/04/12 Javascript
JavaScript实现选中文字提示新浪微博分享效果
2017/06/15 Javascript
JavaScript编写的网页小游戏,很给力
2017/08/18 Javascript
解决在vue+webpack开发中出现两个或多个菜单公用一个组件问题
2017/11/28 Javascript
vue中改变选中当前项的显示隐藏或者状态的实现方法
2018/02/08 Javascript
Vue的双向数据绑定实现原理解析
2020/02/17 Javascript
Python多线程和队列操作实例
2015/06/21 Python
pytorch掉坑记录:model.eval的作用说明
2020/06/23 Python
python实现定时发送邮件到指定邮箱
2020/12/23 Python
python tkinter实现下载进度条及抖音视频去水印原理
2021/02/07 Python
canvas如何绘制钟表的方法
2017/12/13 HTML / CSS
Jimmy Choo美国官网:周仰杰鞋子品牌
2018/06/08 全球购物
Desigual美国官方网站:西班牙服装品牌
2019/03/29 全球购物
音乐学院硕士生的自我评价分享
2013/11/01 职场文书
给全校老师的建议书
2014/03/13 职场文书
公务员保密承诺书
2014/03/27 职场文书
做一个有道德的人活动实施方案
2014/08/23 职场文书
党员个人对照检查材料
2014/10/01 职场文书
店面出租协议书范本
2014/11/28 职场文书
英文升职感谢信
2015/01/23 职场文书
Python实现文本文件拆分写入到多个文本文件的方法
2021/04/18 Python
JavaScript小技巧带你提升你的代码技能
2021/09/15 Javascript
关于k8s环境部署mysql主从的问题
2022/03/13 MySQL