pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PYTHON压平嵌套列表的简单实现
Jun 08 Python
Python中取整的几种方法小结
Jan 06 Python
Python基础教程之利用期物处理并发
Mar 29 Python
python提取具有某种特定字符串的行数据方法
Dec 11 Python
如何通过Python实现标签云算法
Jul 02 Python
Python将主机名转换为IP地址的方法
Aug 14 Python
PyTorch实现AlexNet示例
Jan 14 Python
Python configparser模块常用方法解析
May 22 Python
PyCharm中配置PySide2的图文教程
Jun 18 Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 Python
python 实现批量图片识别并翻译
Nov 02 Python
python re的findall和finditer的区别详解
Nov 15 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
php cURL和Rolling cURL并发方式比较
2013/10/30 PHP
PHP四大安全策略
2014/03/12 PHP
ThinkPHP自动完成中使用函数与回调方法实例
2014/11/29 PHP
如何实现php图片等比例缩放
2015/07/28 PHP
JavaScript 捕获窗口关闭事件
2009/07/26 Javascript
基于jQuery实现的扇形定时器附源码下载
2015/10/20 Javascript
轻松学习jQuery插件EasyUI EasyUI创建RSS Feed阅读器
2015/11/30 Javascript
bootstrap布局中input输入框右侧图标点击功能
2016/05/16 Javascript
JS之相等操作符详解
2016/09/13 Javascript
jQuery扩展实现text提示还能输入多少字节的方法
2016/11/28 Javascript
如何快速上手Vuex
2017/02/14 Javascript
Three.js获取鼠标点击的三维坐标示例代码
2017/03/24 Javascript
Angular 4 指令快速入门教程
2017/06/07 Javascript
ZeroClipboard.js使用一个flash复制多个文本框
2017/06/19 Javascript
Node.js利用js-xlsx处理Excel文件的方法详解
2017/07/05 Javascript
AngularJS实现的简单拖拽功能示例
2018/01/02 Javascript
angular6开发steps步骤条组件
2019/07/04 Javascript
JavaScript RegExp 对象用法详解
2019/09/24 Javascript
让IDE识别webpack的别名alias的实现方法
2020/05/06 Javascript
Js Snowflake(雪花算法)生成随机ID的实现方法
2020/08/26 Javascript
[31:00]2014 DOTA2华西杯精英邀请赛5 24 NewBee VS iG
2014/05/25 DOTA
Python之eval()函数危险性浅析
2014/07/03 Python
django接入新浪微博OAuth的方法
2015/06/29 Python
Python利用ElementTree模块处理XML的方法详解
2017/08/31 Python
Python中Proxypool库的安装与配置
2018/10/19 Python
pycharm开发一个简单界面和通用mvc模板(操作方法图解)
2020/05/27 Python
英国时尚家具、家居饰品及礼品商店:Graham & Green
2016/09/15 全球购物
一家专门做特卖的网站:唯品会
2016/10/09 全球购物
美国香薰蜡烛品牌:PADDYWAX
2018/10/06 全球购物
阿玛尼意大利官网:Armani意大利
2018/10/30 全球购物
数控技术应用个人求职信范文
2014/02/03 职场文书
开学典礼感言
2014/02/16 职场文书
检查机关党的群众路线个人整改措施
2014/10/04 职场文书
婚庆答谢词大全
2015/09/29 职场文书
Nginx工作原理和优化总结。
2021/04/02 Servers
JS高级程序设计之class继承重点详解
2022/07/07 Javascript