pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现日常记账本小程序
Mar 10 Python
Python cookbook(字符串与文本)针对任意多的分隔符拆分字符串操作示例
Apr 19 Python
Python面向对象类继承和组合实例分析
May 28 Python
解决Pycharm调用Turtle时 窗口一闪而过的问题
Feb 16 Python
python多线程http压力测试脚本
Jun 25 Python
python 返回一个列表中第二大的数方法
Jul 09 Python
TensorFlow实现指数衰减学习率的方法
Feb 05 Python
Python调用接口合并Excel表代码实例
Mar 31 Python
Django封装交互接口代码
Jul 12 Python
Python实现七个基本算法的实例代码
Oct 08 Python
Python pygame实现中国象棋单机版源码
Jun 20 Python
深入理解Pytorch微调torchvision模型
Nov 11 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
PHP计划任务、定时执行任务的实现代码
2011/04/23 PHP
PHP laravel中的多对多关系实例详解
2017/06/07 PHP
PHP面向对象之工作单元(实例讲解)
2017/06/26 PHP
PHP按符号截取字符串的指定部分的实现方法
2018/09/10 PHP
用CSS+JS实现的进度条效果效果
2007/06/05 Javascript
js可突破windows弹退效果代码
2008/08/09 Javascript
js 完美图片新闻轮转效果,腾讯大粤网首页图片轮转改造而来
2011/11/21 Javascript
jQuery版仿Path菜单效果
2011/12/15 Javascript
nodejs的require模块(文件模块/核心模块)及路径介绍
2013/01/14 NodeJs
JQuery入门——用one()方法绑定事件处理函数(仅触发一次)
2013/02/05 Javascript
js单向链表的具体实现实例
2013/06/21 Javascript
Javascript 按位与赋值运算符 (&=)使用介绍
2014/02/04 Javascript
Node.js插件的正确编写方式
2014/08/03 Javascript
盘点javascript 正则表达式中 中括号的【坑】
2016/03/16 Javascript
利用jsonp跨域调用百度js实现搜索框智能提示
2016/08/24 Javascript
BootStrap Validator对于隐藏域验证和程序赋值即时验证的问题浅析
2016/12/01 Javascript
vue+element实现批量删除功能的示例
2018/02/28 Javascript
解决淘宝cnpm 安装后cnpm不是内部或外部命令的问题
2018/05/17 Javascript
浅谈Javascript常用正则表达式应用
2019/03/08 Javascript
你知道JavaScript Symbol类型怎么用吗
2020/01/08 Javascript
Vue组件为什么data必须是一个函数
2020/06/11 Javascript
深入了解Vue.js 混入(mixins)
2020/07/23 Javascript
vue中解决chrome浏览器自动播放音频和MP3语音打包到线上的实现方法
2020/10/09 Javascript
UEditor 自定义图片视频尺寸校验功能的实现代码
2020/10/20 Javascript
[03:07]2015国际邀请赛选手档案EHOME.rOtK 是什么让他落泪?
2015/07/31 DOTA
Java Web开发过程中登陆模块的验证码的实现方式总结
2016/05/25 Python
python检测空间储存剩余大小和指定文件夹内存占用的实例
2018/06/11 Python
Python中一般处理中文的几种方法
2019/03/06 Python
python按行读取文件并找出其中指定字符串
2019/08/08 Python
Django自定义模板过滤器和标签的实现方法
2019/08/21 Python
adidas泰国官网:adidas TH
2020/07/11 全球购物
科颜氏印度官网:Kiehl’s印度
2021/02/20 全球购物
商务日语专业的自荐信
2014/05/23 职场文书
中学生旷课检讨书2篇
2014/10/09 职场文书
Python 多线程处理任务实例
2021/11/07 Python
mysql 体系结构和存储引擎介绍
2022/05/06 MySQL