pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的判断语句模拟三目运算
Apr 24 Python
Python中每次处理一个字符的5种方法
May 21 Python
python中map()与zip()操作方法
Feb 27 Python
Python的Flask框架标配模板引擎Jinja2的使用教程
Jul 12 Python
Python判断两个对象相等的原理
Dec 12 Python
Golang GBK转UTF-8的例子
Aug 26 Python
基于Tensorflow高阶读写教程
Feb 10 Python
pycharm新建Vue项目的方法步骤(图文)
Mar 04 Python
通过代码实例了解Python sys模块
Sep 14 Python
Django自定义YamlField实现过程解析
Nov 11 Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 Python
opencv 分类白天与夜景视频的方法
Jun 05 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP下一个非常全面获取图象信息的函数
2008/11/20 PHP
php中计算中文字符串长度、截取中文字符串的函数代码
2011/08/09 PHP
windwos下使用php连接oracle数据库的过程分享
2014/05/26 PHP
php截取字符串函数substr,iconv_substr,mb_substr示例以及优劣分析
2014/06/10 PHP
PHP解析RSS的方法
2015/03/05 PHP
thinkphp ajaxfileupload实现异步上传图片的示例
2017/08/28 PHP
浅谈PHP之ThinkPHP框架使用详解
2020/07/21 PHP
图片自动缩小的js代码,用以防止图片撑破页面
2007/03/12 Javascript
List the Codec Files on a Computer
2007/06/11 Javascript
Domino中运用jQuery读取视图内容的方法
2009/10/21 Javascript
提示$ is not defined错误分析及解决
2013/04/09 Javascript
利用jquery写的左右轮播图特效
2014/02/12 Javascript
JS动态修改图片的URL(src)的方法
2015/04/01 Javascript
给angular加上动画效遇到的问题总结
2016/02/17 Javascript
Node.js开发者必须了解的4个JS要点
2016/02/21 Javascript
JQuery.validate在ie8下不支持的快速解决方法
2016/05/18 Javascript
jQuery移除或禁用html元素点击事件常用方法小结
2017/02/10 Javascript
jquery加载单文件vue组件的方法
2017/06/20 jQuery
简单理解Vue中的nextTick方法
2018/01/30 Javascript
jQuery实现使用sort方法对json数据排序的方法
2018/04/17 jQuery
30分钟快速入门掌握ES6/ES2015的核心内容(下)
2018/04/18 Javascript
详解VUE调用本地json的使用方法
2019/05/15 Javascript
JavaScript 处理树数据结构的方法示例
2019/06/16 Javascript
微信小程序scroll-view的滚动条设置实现
2020/03/02 Javascript
Vue自定义表单内容检查rules实例
2020/10/30 Javascript
Python中的Numpy入门教程
2014/04/26 Python
Python过滤函数filter()使用自定义函数过滤序列实例
2014/08/26 Python
详解HTML5中的picture元素响应式处理图片
2018/01/03 HTML / CSS
乌克兰时尚鞋子和衣服购物网站:Born2be
2018/05/24 全球购物
社区志愿者心得体会
2014/01/03 职场文书
企业员工集体活动方案
2014/08/17 职场文书
产品委托授权书范本
2014/09/16 职场文书
2014年物资管理工作总结
2014/12/02 职场文书
三八节活动主持词
2015/07/04 职场文书
2016年度农村党员干部主题教育活动总结
2016/04/06 职场文书
Java使用HttpClient实现文件下载
2022/08/14 Java/Android