pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时采集摄像头图像上传ftp服务器功能实现
Dec 23 Python
Python交互环境下实现输入代码
Jun 22 Python
python去除文件中重复的行实例
Jun 29 Python
Python中时间datetime的处理与转换用法总结
Feb 18 Python
详解python列表生成式和列表生成式器区别
Mar 27 Python
python面向对象实现名片管理系统文件版
Apr 26 Python
python保存log日志,实现用log日志画图
Dec 24 Python
python numpy实现多次循环读取文件 等间隔过滤数据示例
Mar 14 Python
解决django migrate报错ORA-02000: missing ALWAYS keyword
Jul 02 Python
python简单实现9宫格图片实例
Sep 03 Python
django项目中使用云片网发送短信验证码的实现
Jan 19 Python
学点简单的Django之第一个Django程序的实现
Feb 24 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP 删除文件与文件夹操作 unlink()与rmdir()这两个函数的使用
2011/07/17 PHP
跟我学Laravel之请求与输入
2014/10/15 PHP
PHP QRCODE生成彩色二维码的方法
2016/05/19 PHP
文本框回车提交与禁止提交示例
2013/09/27 Javascript
node.js中的fs.lstat方法使用说明
2014/12/16 Javascript
javacript使用break内层跳出外层循环分析
2015/01/12 Javascript
Nodejs为什么选择javascript为载体语言
2015/01/13 NodeJs
jquery动态切换背景图片的简单实现方法
2016/05/14 Javascript
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
强大的JavaScript响应式图表Chartist.js的使用
2017/09/13 Javascript
解决JQuery全选/反选第二次失效的问题
2017/10/11 jQuery
NodeJS搭建HTTP服务器的实现步骤
2018/10/12 NodeJs
layui--js控制switch的切换方法
2019/09/03 Javascript
VsCode与Node.js知识点详解
2019/09/05 Javascript
[06:53]DOTA2每周TOP10 精彩击杀集锦vol.3
2014/06/25 DOTA
[04:49]2014DOTA2国际邀请赛 Newbee顺利挺进总决赛 ImbaTV独家专访
2014/07/19 DOTA
浅要分析Python程序与C程序的结合使用
2015/04/07 Python
Python列表list操作相关知识小结
2020/01/29 Python
Django Admin后台添加数据库视图过程解析
2020/04/01 Python
Python Selenium XPath根据文本内容查找元素的方法
2020/12/07 Python
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
一站式跨境收款解决方案:Payoneer(派安盈)
2018/09/06 全球购物
Farfetch台湾官网:奢侈品牌时尚购物平台
2019/06/17 全球购物
桥梁与隧道工程专业本科生求职信
2013/10/08 职场文书
《诚实与信任》教学反思
2014/04/10 职场文书
2014年党支部学习材料
2014/05/19 职场文书
社团活动总结报告
2014/06/27 职场文书
市场策划求职信
2014/08/07 职场文书
2014年图书馆工作总结
2014/11/25 职场文书
区域销售经理岗位职责
2015/04/02 职场文书
教师岗位职责范本
2015/04/02 职场文书
个人更名证明
2015/06/23 职场文书
医疗纠纷调解协议书
2015/08/06 职场文书
Spring Data JPA使用JPQL与原生SQL进行查询的操作
2021/06/15 Java/Android
【海涛七七解说】DCG第二周:DK VS 天禄
2022/04/01 DOTA
cypress测试本地web应用
2022/06/01 Javascript