pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接远程ftp服务器并列出目录下文件的方法
Apr 01 Python
python optparse模块使用实例
Apr 09 Python
python实现bucket排序算法实例分析
May 04 Python
Python SqlAlchemy动态添加数据表字段实例解析
Feb 07 Python
Python3.6连接Oracle数据库的方法详解
May 18 Python
IntelliJ IDEA安装运行python插件方法
Dec 10 Python
Python实现投影法分割图像示例(二)
Jan 17 Python
Python爬虫headers处理及网络超时问题解决方案
Jun 19 Python
keras分类之二分类实例(Cat and dog)
Jul 09 Python
3分钟看懂Python后端必须知道的Django的信号机制
Jul 26 Python
详解基于Facecognition+Opencv快速搭建人脸识别及跟踪应用
Jan 21 Python
Python实现的扫码工具居然这么好用!
Jun 07 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP面向对象的使用教程 简单数据库连接
2006/11/25 PHP
用PHP制作的意见反馈表源码
2007/03/11 PHP
解析PHP中一些可能会被忽略的问题
2013/06/21 PHP
PHP模拟asp.net的StringBuilder类实现方法
2015/08/08 PHP
Laravel框架Request、Response及Session操作示例
2019/05/06 PHP
php数组遍历类与用法示例
2019/05/24 PHP
JavaScript 内置对象属性及方法集合
2010/07/04 Javascript
网站页面自动跳转实现方法PHP、JSP(上)
2010/08/01 Javascript
判断用户是否在线的代码
2011/03/05 Javascript
jquery ajax 局部无刷新更新数据的实现案例
2014/02/08 Javascript
jquery中子元素和后代元素的区别示例介绍
2014/04/02 Javascript
JS实现鼠标箭头变成一个燃烧烛光效果的方法
2015/02/28 Javascript
JS+CSS实现简单的二级下拉导航菜单效果
2015/09/21 Javascript
jQuery Checkbox 全选 反选的简单实例
2016/11/29 Javascript
JavaScript & jQuery完美判断图片是否加载完毕
2017/01/08 Javascript
Vue-router 切换组件页面时进入进出动画方法
2018/09/01 Javascript
玩转Koa之koa-router原理解析
2018/12/29 Javascript
了解javascript中的Dom操作
2019/05/27 Javascript
在Python中操作列表之List.append()方法的使用
2015/05/20 Python
Python3实现抓取javascript动态生成的html网页功能示例
2017/08/22 Python
Python中format()格式输出全解
2019/04/12 Python
Python实现将HTML转成PDF的方法分析
2019/05/04 Python
与Django结合利用模型对上传图片预测的实例详解
2019/08/07 Python
django项目中使用手机号登录的实例代码
2019/08/15 Python
python 的topk算法实例
2020/04/02 Python
HTML5实现动画效果的方式汇总
2016/02/29 HTML / CSS
美国高端婴童品牌:Hanna Andersson
2016/10/30 全球购物
介绍一下.net和Java的特点和区别
2012/09/26 面试题
国际贸易个人求职信范文
2014/01/04 职场文书
英文自荐信常用句子
2014/03/26 职场文书
竞选班长演讲稿400字
2014/08/22 职场文书
学生会感恩节活动方案
2014/10/11 职场文书
辞职信怎么写?
2019/05/21 职场文书
详解Redis基本命令与使用场景
2021/06/01 Redis
第四次工业革命,打工人与机器人的竞争
2022/04/21 数码科技
vue elementUI批量上传文件
2022/04/26 Vue.js