pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python和php通信乱码问题解决方法
Apr 15 Python
Python挑选文件夹里宽大于300图片的方法
Mar 05 Python
python中的内置函数max()和min()及mas()函数的高级用法
Mar 29 Python
Python之循环结构
Jan 15 Python
浅谈Python类中的self到底是干啥的
Nov 11 Python
tensorflow指定GPU与动态分配GPU memory设置
Feb 03 Python
python 非线性规划方式(scipy.optimize.minimize)
Feb 11 Python
python filecmp.dircmp实现递归比对两个目录的方法
May 22 Python
如何利用python web框架做文件流下载的实现示例
Jun 02 Python
10个python爬虫入门实例(小结)
Nov 01 Python
Python中OpenCV实现查找轮廓的实例
Jun 08 Python
PYTHON使用Matplotlib去实现各种条形图的绘制
Mar 22 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
最令PHP初学者头痛的十四个问题
2006/07/12 PHP
生成缩略图
2006/10/09 PHP
php框架CodeIgniter使用redis的方法分析
2018/04/13 PHP
PHP使用curl_multi实现并发请求的方法示例
2018/04/29 PHP
PHP利用curl发送HTTP请求的实例代码
2020/07/09 PHP
PHP7 list() 函数修改
2021/03/09 PHP
用javascript操作xml
2006/11/04 Javascript
JavaScript 节点操作 以及DOMDocument属性和方法
2007/12/06 Javascript
JS控制文本框textarea输入字数限制的方法
2013/06/17 Javascript
jquery实现动态菜单的实例代码
2013/11/28 Javascript
jQuery图片轮播滚动切换代码分享
2020/04/20 Javascript
基于jQuery和CSS3制作数字时钟附源码下载(jquery篇)
2015/11/24 Javascript
Javascript实现检测客户端类型代码封包
2015/12/03 Javascript
深入分析node.js的异步API和其局限性
2016/09/05 Javascript
nodejs进阶(6)—连接MySQL数据库示例
2017/01/07 NodeJs
JavaScript实现简易的天数计算器实例【附demo源码下载】
2017/01/18 Javascript
React Native之TextInput组件解析示例
2017/08/22 Javascript
浅谈mint-ui 填坑之路
2017/11/06 Javascript
Javascript中的作用域及块级作用域
2017/12/08 Javascript
Vue2.5 结合 Element UI 之 Table 和 Pagination 组件实现分页功能
2018/01/26 Javascript
js将键值对字符串转为json字符串的方法
2018/03/30 Javascript
Node.js折腾记一:读指定文件夹,输出该文件夹的文件树详解
2019/04/20 Javascript
vue监听浏览器原生返回按钮,进行路由转跳操作
2020/09/09 Javascript
[45:14]Optic vs VP 2018国际邀请赛淘汰赛BO3 第二场 8.24
2018/08/25 DOTA
python基于socket实现网络广播的方法
2015/04/29 Python
python通过函数属性实现全局变量的方法
2015/05/16 Python
opencv调整图像亮度对比度的示例代码
2019/09/27 Python
基于python traceback实现异常的获取与处理
2019/12/13 Python
报到证丢失证明
2014/01/11 职场文书
网上卖盒饭创业计划书
2014/01/26 职场文书
教师爱岗敬业演讲稿
2014/05/05 职场文书
中层干部培训方案
2014/06/16 职场文书
甜美蛋糕店的创业计划书模板,拿来即用!
2019/08/21 职场文书
Python通过m3u8文件下载合并ts视频的操作
2021/04/16 Python
详解thinkphp的Auth类认证
2021/05/28 PHP
对讲机的最大通讯距离是多少
2022/02/18 无线电