pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python提取html文件中的特定数据的实现代码
Mar 24 Python
python通过post提交数据的方法
May 06 Python
如何将python中的List转化成dictionary
Aug 15 Python
Python对字符串实现去重操作的方法示例
Aug 11 Python
详解python的ORM中Pony用法
Feb 09 Python
python基础梳理(一)(推荐)
Apr 06 Python
详解如何在cmd命令窗口中搭建简单的python开发环境
Aug 29 Python
python绘制BA无标度网络示例代码
Nov 21 Python
Python turtle库绘制菱形的3种方式小结
Nov 23 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
Python Selenium实现无可视化界面过程解析
Aug 25 Python
python四种出行路线规划的实现
Jun 23 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
解析用PHP读写音频文件信息的详解(支持WMA和MP3)
2013/05/10 PHP
整理一些JavaScript的IE和火狐的兼容性注意事项
2011/03/17 Javascript
jquery动画2.元素坐标动画效果(创建一个图片走廊)
2012/08/24 Javascript
使用Chrome调试JavaScript的断点设置和调试技巧
2014/12/16 Javascript
jquery中$each()方法的使用指南
2015/04/30 Javascript
JavaScript 节流函数 Throttle 详解
2016/07/04 Javascript
AngularJs 国际化(I18n/L10n)详解
2016/09/01 Javascript
vue.js将unix时间戳转换为自定义时间格式
2017/01/03 Javascript
Vue实例中生命周期created和mounted的区别详解
2017/08/25 Javascript
jQuery实现html双向绑定功能示例
2017/10/09 jQuery
jQuery+ajax实现动态添加表格tr td功能示例
2018/04/23 jQuery
原生JS实现DOM加载完成马上执行JS代码的方法
2018/09/07 Javascript
一些可能会用到的Node.js面试题
2019/06/15 Javascript
Vue 动态添加路由及生成菜单的方法示例
2019/06/20 Javascript
Python实现检测服务器是否可以ping通的2种方法
2015/01/01 Python
浅谈Python 中整型对象的存储问题
2016/05/16 Python
Python遍历目录中的所有文件的方法
2016/07/08 Python
Python中运算符"=="和"is"的详解
2016/10/08 Python
浅谈numpy中linspace的用法 (等差数列创建函数)
2017/06/07 Python
Python利用multiprocessing实现最简单的分布式作业调度系统实例
2017/11/14 Python
Python设计模式之策略模式实例详解
2019/01/21 Python
如何用Python破解wifi密码过程详解
2019/07/12 Python
win10下安装Anaconda的教程(python环境+jupyter_notebook)
2019/10/23 Python
使用PyQt5实现图片查看器的示例代码
2020/04/21 Python
Python3爬虫中识别图形验证码的实例讲解
2020/07/30 Python
基于HTML5的WebGL实现json和echarts图表展现在同一个界面
2017/10/26 HTML / CSS
法国珠宝店:CLEOR
2017/01/29 全球购物
外语学院毕业生的自我鉴定
2013/11/28 职场文书
2014迎接教师节演讲稿
2014/09/10 职场文书
2014年社区工作总结
2014/11/18 职场文书
2014年人大工作总结
2014/12/10 职场文书
2014年服务员个人工作总结
2014/12/23 职场文书
综合实践活动报告
2015/02/05 职场文书
幼儿园国庆节活动总结
2015/03/23 职场文书
2015年网管个人工作总结
2015/05/22 职场文书
Mysql分析设计表主键为何不用uuid
2022/03/31 MySQL