pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 随机生成中文验证码的实例代码
Mar 20 Python
go和python调用其它程序并得到程序输出
Feb 10 Python
python单线程实现多个定时器示例
Mar 30 Python
Python标准库之sqlite3使用实例
Nov 25 Python
Python基类函数的重载与调用实例分析
Jan 12 Python
举例详解Python中yield生成器的用法
Aug 05 Python
python 2.7.14安装图文教程
Apr 08 Python
Python 通配符删除文件的实例
Apr 24 Python
python 3.7.0 安装配置方法图文教程
Aug 27 Python
简单了解Django ORM常用字段类型及参数配置
Jan 07 Python
用于ETL的Python数据转换工具详解
Jul 21 Python
Elasticsearch 批量操作
Apr 19 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
第八节 访问方式 [8]
2006/10/09 PHP
php FPDF类库应用实现代码
2009/03/20 PHP
常用的PHP数据库操作方法(MYSQL版)
2011/06/08 PHP
基于Zend的Captcha机制的应用
2013/05/02 PHP
php判断对象是派生自哪个类的方法
2015/06/20 PHP
JavaScript实现删除电脑的关机键
2016/07/26 PHP
PHP 与 UTF-8 的最佳实践详细介绍
2017/01/04 PHP
PHP使用new StdClass()创建空对象的方法分析
2017/06/06 PHP
Laravel框架源码解析之反射的使用详解
2020/05/14 PHP
jQuery 表单验证插件formValidation实现个性化错误提示
2009/06/23 Javascript
Ext 今日学习总结
2010/09/19 Javascript
js+数组实现网页上显示时间/星期几的实用方法
2013/01/18 Javascript
用于deeplink的js方法(判断手机是否安装app)
2014/04/02 Javascript
JS实现div居中示例
2014/04/17 Javascript
jquery序列化方法实例分析
2015/06/10 Javascript
jquery带翻页动画的电子杂志代码分享
2015/08/21 Javascript
JS实现的简单鼠标跟随DiV层效果完整实例
2015/10/31 Javascript
BootStrap 动态添加验证项和取消验证项的实现方法
2016/09/28 Javascript
angular中的cookie读写方法
2017/08/02 Javascript
JS简单实现点击跳转登陆邮箱功能的方法
2017/10/31 Javascript
vue webpack打包优化操作技巧
2018/02/22 Javascript
Vue.js 十五分钟入门图文教程
2018/09/12 Javascript
详解Vue中组件的缓存
2019/04/20 Javascript
基于mpvue的简单弹窗组件mptoast使用详解
2019/08/02 Javascript
ES6字符串的扩展实例
2020/12/21 Javascript
Python中文编码那些事
2014/06/25 Python
78行Python代码实现现微信撤回消息功能
2018/07/26 Python
python数据预处理 :数据抽样解析
2020/02/24 Python
python 如何用urllib与服务端交互(发送和接收数据)
2021/03/04 Python
一个不错的HTML5 Canvas多层点击事件监听实例
2014/04/29 HTML / CSS
幼儿园毕业典礼主持词
2014/03/21 职场文书
《蜗牛的奖杯》教后反思
2014/04/24 职场文书
党员一帮一活动总结
2014/07/08 职场文书
2015年世界粮食日演讲稿
2015/03/20 职场文书
社区文明倡议书
2015/04/28 职场文书
反邪教警示教育活动总结
2015/05/09 职场文书