pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 算法 排序实现快速排序
Jun 05 Python
Python中变量交换的例子
Aug 25 Python
python使用PyGame模块播放声音的方法
May 20 Python
Python编程之基于概率论的分类方法:朴素贝叶斯
Nov 11 Python
python3 selenium 切换窗口的几种方法小结
May 21 Python
python3+opencv3识别图片中的物体并截取的方法
Dec 05 Python
python生成器与迭代器详解
Jan 01 Python
对Python中的条件判断、循环以及循环的终止方法详解
Feb 08 Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 Python
python生成13位或16位时间戳以及反向解析时间戳的实例
Mar 03 Python
python实现数字炸弹游戏
Jul 17 Python
python判断元素是否存在的实例方法
Sep 24 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
使用PHPMailer发送邮件实例
2017/02/15 PHP
PHP 加密 Password Hashing API基础知识点
2020/03/02 PHP
Jquery 基础学习笔记之文档处理
2009/05/29 Javascript
jquery中的事件处理详细介绍
2013/06/24 Javascript
面试常见的js算法题
2017/03/23 Javascript
JavaScript 数组去重并统计重复元素出现的次数实例
2017/12/14 Javascript
详解webpack+express多页站点开发
2017/12/22 Javascript
Javascript中prototype与__proto__的关系详解
2018/03/11 Javascript
深入浅析Vue.js计算属性和侦听器
2018/05/05 Javascript
vue下history模式刷新后404错误解决方法
2018/08/18 Javascript
Vue通过ref父子组件拿值方法
2018/09/12 Javascript
vue实现的双向数据绑定操作示例
2018/12/04 Javascript
node和vue实现商城用户地址模块
2018/12/05 Javascript
3分钟读懂移动端rem使用方法(推荐)
2019/05/06 Javascript
jquery+css实现Tab栏切换的代码实例
2019/05/14 jQuery
[04:22]DOTA2大事件之护国神翼
2020/08/14 DOTA
Django如何自定义分页
2018/09/25 Python
在python中pandas读文件,有中文字符的方法
2018/12/12 Python
使用python将多个excel文件合并到同一个文件的方法
2019/07/09 Python
PyQt 图解Qt Designer工具的使用方法
2019/08/06 Python
python输出国际象棋棋盘的实例分享
2020/11/26 Python
详解HTML5通讯录获取指定多个人的信息
2016/12/20 HTML / CSS
利用html5的websocket实现websocket聊天室
2013/12/12 HTML / CSS
马来西亚在线时尚女装商店:KEI MAG
2017/09/28 全球购物
英国第一摩托车和摩托车越野配件商店:GhostBikes
2019/03/10 全球购物
美国家居装饰购物网站:Amanda Lindroth
2020/03/25 全球购物
税务干部鉴定材料
2014/02/11 职场文书
管理部副部长岗位职责范文
2014/03/09 职场文书
求职简历自我评价范例
2014/03/12 职场文书
艺术设计专业求职自荐信
2014/05/19 职场文书
办公室领导干部作风整顿个人整改措施
2014/09/17 职场文书
学校会议通知范文
2015/04/15 职场文书
导游词之广东佛山(南风古灶)
2019/09/24 职场文书
python playwright 自动等待和断言详解
2021/11/27 Python
python中的sys模块和os模块
2022/03/20 Python
i5-10400f处理相当于i7多少水平
2022/04/19 数码科技