pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
全面解读Python Web开发框架Django
Jun 30 Python
Python调用C/C++动态链接库的方法详解
Jul 22 Python
python在windows和linux下获得本机本地ip地址方法小结
Mar 20 Python
Python中常见的异常总结
Feb 20 Python
python调用API实现智能回复机器人
Apr 10 Python
python实现多进程代码示例
Oct 31 Python
Python 实现域名解析为ip的方法
Feb 14 Python
如何用Python做一个微信机器人自动拉群
Jul 03 Python
Python简易版停车管理系统
Aug 12 Python
python 矢量数据转栅格数据代码实例
Sep 30 Python
python3中的eval和exec的区别与联系
Oct 10 Python
Python上下文管理器用法及实例解析
Nov 11 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP实现文件下载详解
2014/11/27 PHP
magento后台无法登录解决办法的两种方法
2016/12/09 PHP
thinkphp实现附件上传功能
2017/05/26 PHP
JavaScript 基于原型的对象(创建、调用)
2009/10/16 Javascript
基于JQuery的简单实现折叠菜单代码
2010/09/15 Javascript
从盛大通行证上摘下来的身份证验证js代码
2011/01/11 Javascript
JavaScript 基础篇之运算符、语句(二)
2012/04/07 Javascript
Jquery 数据选择插件Pickerbox使用介绍
2012/08/24 Javascript
jquery的$getjson调用并获取远程的JSON字符串问题
2012/12/10 Javascript
原生JS实现加入收藏夹的代码
2013/10/24 Javascript
javascript实现依次输入input自动定焦
2014/12/23 Javascript
jquery插件corner实现圆角边框的方法
2015/03/09 Javascript
JavaScript实现数字数组正序排列的方法
2015/04/06 Javascript
JavaScript中的toLocaleLowerCase()方法使用详解
2015/06/06 Javascript
JQuery实现样式设置、追加、移除与切换的方法
2015/06/11 Javascript
jQuery+ajax实现实用的点赞插件代码
2016/07/06 Javascript
快速掌握jQuery插件WebUploader文件上传
2016/11/07 Javascript
jquery实现轮播图效果
2017/02/13 Javascript
mpvue中配置vuex并持久化到本地Storage图文教程解析
2018/03/15 Javascript
详解vue的双向绑定原理及实现
2019/05/05 Javascript
了解JavaScript函数中的默认参数
2019/05/30 Javascript
Vue路由模块化配置的完整步骤
2019/08/14 Javascript
解决LayUI加上form.render()下拉框和单选以及复选框不出来的问题
2019/09/27 Javascript
vue3.0实现点击切换验证码(组件)及校验
2020/11/18 Vue.js
python查看FTP是否能连接成功的方法
2015/07/30 Python
python 文件操作api(文件操作函数)
2016/08/28 Python
python安装PIL模块时Unable to find vcvarsall.bat错误的解决方法
2016/09/19 Python
基于Python Numpy的数组array和矩阵matrix详解
2018/04/04 Python
Pandas之Dropna滤除缺失数据的实现方法
2019/06/25 Python
pandas基于时间序列的固定时间间隔求均值的方法
2019/07/04 Python
python进程池实现的多进程文件夹copy器完整示例
2019/11/27 Python
合作经营协议书范本
2014/09/16 职场文书
会计求职自荐信范文
2015/03/04 职场文书
2016年劳模先进事迹材料
2016/02/25 职场文书
navicat 连接Ubuntu虚拟机的mysql的操作方法
2022/04/02 MySQL
html解决浏览器记住密码输入框的问题
2023/05/07 HTML / CSS