pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中list列表的高级函数
May 17 Python
python3实现ftp服务功能(服务端 For Linux)
Mar 24 Python
基于python(urlparse)模板的使用方法总结
Oct 13 Python
python list是否包含另一个list所有元素的实例
May 04 Python
使用python读取.text文件特定行的数据方法
Jan 28 Python
python创造虚拟环境方法总结
Mar 04 Python
python语言基本语句用法总结
Jun 11 Python
python根据多个文件名批量查找文件
Aug 13 Python
python 发送json数据操作实例分析
Oct 15 Python
pytorch点乘与叉乘示例讲解
Dec 27 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
Feb 26 Python
详解selenium + chromedriver 被反爬的解决方法
Oct 28 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
PHP新手上路(三)
2006/10/09 PHP
PHP常用代码
2006/11/23 PHP
基于php设计模式中单例模式的应用分析
2013/05/15 PHP
php通过数组实现多条件查询实现方法(字符串分割)
2014/05/06 PHP
php文件包含的几种方式总结
2019/09/19 PHP
JS简单的图片放大缩小的两种方法
2013/11/11 Javascript
零基础搭建Node.js、Express、Ejs、Mongodb服务器及应用开发入门
2014/12/20 Javascript
jquery实现可拖拽弹出层特效
2015/01/04 Javascript
Node.js开发之访问Redis数据库教程
2015/01/14 Javascript
AngularJS基础知识笔记之过滤器
2015/05/10 Javascript
浅谈Jquery核心函数
2015/06/18 Javascript
jQuery超精致图片轮播幻灯片特效代码分享
2015/09/10 Javascript
AngularJS ng-style中使用filter
2016/09/21 Javascript
js循环map 获取所有的key和value的实现代码(json)
2018/05/09 Javascript
layui 监听表格复选框选中值的方法
2018/08/15 Javascript
js中this的指向问题归纳总结
2018/11/28 Javascript
JavaScript实现抖音罗盘时钟
2019/10/11 Javascript
JavaScript中layim之整合右键菜单的示例代码
2021/02/06 Javascript
python快速排序代码实例
2013/11/21 Python
jupyter安装小结
2016/03/13 Python
python实现数据预处理之填充缺失值的示例
2017/12/22 Python
Python装饰器简单用法实例小结
2018/12/03 Python
Python 实现文件打包、上传与校验的方法
2019/02/13 Python
Python创建字典的八种方式
2019/02/27 Python
Django用户认证系统 User对象解析
2019/08/02 Python
sklearn+python:线性回归案例
2020/02/24 Python
Django ModelForm组件原理及用法详解
2020/10/12 Python
AmazeUI图片轮播效果的示例代码
2020/08/20 HTML / CSS
如何写出高性能的JSP和Servlet
2013/01/22 面试题
历史专业个人求职信范文
2013/12/07 职场文书
文明礼仪演讲稿
2014/05/12 职场文书
酒店端午节活动方案
2014/08/26 职场文书
教师自我剖析材料范文
2014/09/30 职场文书
课题研究阶段性总结
2015/08/13 职场文书
Nginx Rewrite使用场景及配置方法解析
2021/04/01 Servers
TV动画《间谍过家家》公开PV
2022/03/20 日漫