pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python生成随机数的方法
Jan 14 Python
python快速查找算法应用实例
Sep 26 Python
python从网络读取图片并直接进行处理的方法
May 22 Python
Python语言实现机器学习的K-近邻算法
Jun 11 Python
windows下ipython的安装与使用详解
Oct 20 Python
Python自定义函数定义,参数,调用代码解析
Dec 27 Python
替换python字典中的key值方法
Jul 06 Python
解决python中 f.write写入中文出错的问题
Oct 31 Python
Python csv模块使用方法代码实例
Aug 29 Python
Python爬虫库BeautifulSoup获取对象(标签)名,属性,内容,注释
Jan 25 Python
基于python计算滚动方差(标准差)talib和pd.rolling函数差异详解
Jun 08 Python
flask项目集成swagger的方法
Dec 09 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
Linux下 php5 MySQL5 Apache2 phpMyAdmin ZendOptimizer安装与配置[图文]
2008/11/18 PHP
php获取ip的三个属性区别介绍(HTTP_X_FORWARDED_FOR,HTTP_VIA,REMOTE_ADDR)
2012/09/23 PHP
Win7 64位系统下PHP连接Oracle数据库
2014/08/20 PHP
FleaPHP框架数据库查询条件($conditions)写法总结
2016/03/19 PHP
Javascript操作select方法大全[新增、修改、删除、选中、清空、判断存在等]
2008/09/26 Javascript
JQuery学习笔记 nt-child的使用
2011/01/17 Javascript
捕获浏览器关闭、刷新事件不同情况下的处理方法
2013/06/02 Javascript
javascript:window.open弹出窗口的位置问题
2014/03/18 Javascript
Underscore.js 1.3.3 中文注释翻译说明
2015/06/25 Javascript
纯js代码制作的网页时钟特效【附实例】
2016/03/30 Javascript
Node使用Selenium进行前端自动化操作的代码实现
2019/10/10 Javascript
python求解水仙花数的方法
2015/05/11 Python
让Python更加充分的使用Sqlite3
2017/12/11 Python
Django中Forms的使用代码解析
2018/02/10 Python
华为2019校招笔试题之处理字符串(python版)
2019/06/25 Python
Django REST框架创建一个简单的Api实例讲解
2019/11/05 Python
tensorflow 获取所有variable或tensor的name示例
2020/01/04 Python
python实现简单井字棋游戏
2020/03/04 Python
django实现HttpResponse返回json数据为中文
2020/03/27 Python
Python任务调度利器之APScheduler详解
2020/04/02 Python
python读取excel数据并且画图的实现示例
2021/02/08 Python
聊聊Python pandas 中loc函数的使用,及跟iloc的区别说明
2021/03/03 Python
HTML5 body设置全屏背景图片的示例代码
2020/12/08 HTML / CSS
德国网上宠物店:Zoobio
2018/05/23 全球购物
英国剑桥包中文官网:The Cambridge Satchel Company中国
2018/11/06 全球购物
几道Web/Ajax的面试题
2016/11/05 面试题
外贸业务员的岗位职责
2013/11/23 职场文书
董事长秘书岗位职责
2013/11/29 职场文书
班长自荐书范文
2014/02/11 职场文书
医生个人自我剖析材料
2014/10/08 职场文书
辞职信怎么写
2015/02/27 职场文书
岗位聘任报告
2015/03/02 职场文书
公司财务部岗位职责
2015/04/14 职场文书
学雷锋团日活动总结
2015/05/06 职场文书
2015年妇女工作总结
2015/05/14 职场文书
2016年6月份红领巾广播稿
2015/12/21 职场文书