pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用mongoengine操作MongoDB教程
Apr 24 Python
详解Python pygame安装过程笔记
Jun 05 Python
儿童编程python入门
May 08 Python
python3 flask实现文件上传功能
Mar 20 Python
python print输出延时,让其立刻输出的方法
Jan 07 Python
Python实现简单石头剪刀布游戏
Jan 20 Python
python查询文件夹下excel的sheet名代码实例
Apr 02 Python
python实现日志按天分割
Jul 22 Python
python 实现识别图片上的数字
Jul 30 Python
解决python多行注释引发缩进错误的问题
Aug 23 Python
pytorch中tensor.expand()和tensor.expand_as()函数详解
Dec 27 Python
Python中socket网络通信是干嘛的
May 27 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
收音机怀古---春雷3P7图片欣赏
2021/03/02 无线电
php include,include_once,require,require_once
2008/09/05 PHP
PHP isset()与empty()的使用区别详解
2010/08/29 PHP
file_get_contents获取不到网页内容的解决方法
2013/03/07 PHP
PHP的serialize序列化数据以及JSON格式化数据分析
2015/10/10 PHP
让whoops帮我们告别ThinkPHP6的异常页面
2020/03/02 PHP
javascript 对表格的行和列都能加亮显示
2008/12/26 Javascript
Javascript 中的 && 和 || 使用小结
2010/04/25 Javascript
jquery判断checkbox(复选框)是否被选中的代码
2010/10/20 Javascript
js中方法重载如何实现?以及函数的参数问题
2013/08/01 Javascript
JavaScript等比例缩放图片控制超出范围的图片
2013/08/06 Javascript
变量声明时命名与变量作为对象属性时命名的区别解析
2013/12/06 Javascript
js判断设备是否为PC并调整图片大小
2014/02/12 Javascript
JavaScript使用位运算符判断奇数和偶数的方法
2015/06/01 Javascript
JavaScript中的函数(二)
2015/12/23 Javascript
jQuery代码实现表格中点击相应行变色功能
2016/05/09 Javascript
node.js请求HTTPS报错:UNABLE_TO_VERIFY_LEAF_SIGNATURE\的解决方法
2016/12/18 Javascript
jquery 正整数数字校验正则表达式
2017/01/10 Javascript
JavaScript中七种流行的开源机器学习框架
2018/10/11 Javascript
python实现图片变亮或者变暗的方法
2015/06/01 Python
基于Django用户认证系统详解
2018/02/21 Python
解决Python下imread,imwrite不支持中文的问题
2018/12/05 Python
记录Python脚本的运行日志的方法
2019/06/05 Python
python opencv将图片转为灰度图的方法示例
2019/07/31 Python
python构造函数init实例方法解析
2020/01/19 Python
Html5 Canvas 实现一个“刮刮乐”游戏
2019/09/05 HTML / CSS
《狼和小羊》教学反思
2014/04/20 职场文书
党员国庆节演讲稿范文2014
2014/09/21 职场文书
员工工作自我评价
2014/09/26 职场文书
2015年网管个人工作总结
2015/05/22 职场文书
建国大业观后感800字
2015/06/01 职场文书
钢琴师观后感
2015/06/12 职场文书
品德与社会教学反思
2016/02/24 职场文书
Javascript中的解构赋值语法详解
2021/04/02 Javascript
go语言求任意类型切片的长度操作
2021/04/26 Golang
springboot利用redis、Redisson处理并发问题的操作
2021/06/18 Java/Android