pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python运用于数据分析的简单教程
Mar 27 Python
python实现linux下使用xcopy的方法
Jun 28 Python
Python标准库之itertools库的使用方法
Sep 07 Python
Tensorflow的可视化工具Tensorboard的初步使用详解
Feb 11 Python
python实现远程通过网络邮件控制计算机重启或关机
Feb 22 Python
Python代码块批量添加Tab缩进的方法
Jun 25 Python
Python3.5常见内置方法参数用法实例详解
Apr 29 Python
Django集成搜索引擎Elasticserach的方法示例
Jun 04 Python
python 字典访问的三种方法小结
Dec 05 Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 Python
Python切片列表字符串如何实现切换
Aug 06 Python
Python自动巡检H3C交换机实现过程解析
Aug 14 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
用PHP动态创建Flash动画
2006/10/09 PHP
网页游戏开发入门教程三(简单程序应用)
2009/11/02 PHP
php在linux下检测mysql同步状态的方法
2015/01/15 PHP
PHP基于双向链表与排序操作实现的会员排名功能示例
2017/12/26 PHP
js实现可拖动DIV的方法
2013/12/17 Javascript
深入理解JavaScript系列(33):设计模式之策略模式详解
2015/03/03 Javascript
学习使用AngularJS文件上传控件
2016/02/16 Javascript
jQuery 3.0中存在问题及解决办法
2016/07/15 Javascript
同步文本框内容JS代码实现
2016/08/04 Javascript
jQuery内容选择器与表单选择器实例分析
2019/06/28 jQuery
微信小程序使用蓝牙小插件
2019/09/23 Javascript
基于vue+uniapp直播项目实现uni-app仿抖音/陌陌直播室功能
2019/11/12 Javascript
在Vue项目中使用Typescript的实现
2019/12/19 Javascript
jquery添加div实现消息聊天框
2020/02/08 jQuery
javascript实现电商放大镜效果
2020/11/23 Javascript
详解Python中最难理解的点-装饰器
2017/04/03 Python
Python键盘输入转换为列表的实例
2018/06/23 Python
python pyheatmap包绘制热力图
2018/11/09 Python
pandas的qcut()方法详解
2019/07/06 Python
Python pandas实现excel工作表合并功能详解
2019/08/29 Python
python文件操作的简单方法总结
2019/11/07 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
2020/04/13 Python
Python unittest基本使用方法代码实例
2020/06/29 Python
CSS3 特效范例整理
2011/08/22 HTML / CSS
h5实现获取用户地理定位的实例代码
2017/07/17 HTML / CSS
微信小程序“圣诞帽”的实现思路详解
2017/12/28 HTML / CSS
工程师自我评价怎么写
2013/09/19 职场文书
自荐信要包含哪些内容
2013/11/06 职场文书
日语专业毕业生求职信
2013/12/04 职场文书
中层竞聘演讲稿
2014/01/09 职场文书
幼儿园安全责任书
2014/04/14 职场文书
秘书英文求职信
2014/04/16 职场文书
食品质量与安全专业毕业生求职信
2014/08/11 职场文书
关于旅游的活动方案
2014/08/15 职场文书
2015年小学远程教育工作总结
2015/07/28 职场文书
MySQL transaction事务安全示例讲解
2022/06/21 MySQL