pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python time模块详解(常用函数实例讲解,非常好)
Apr 24 Python
Python写的一个简单DNS服务器实例
Jun 04 Python
python 编码规范整理
May 05 Python
Python数据结构之栈、队列及二叉树定义与用法浅析
Dec 27 Python
值得收藏的10道python 面试题
Apr 15 Python
Django CBV类的用法详解
Jul 26 Python
Django CSRF跨站请求伪造防护过程解析
Jul 31 Python
python基于json文件实现的gearman任务自动重启代码实例
Aug 13 Python
pytorch 修改预训练model实例
Jan 18 Python
Python钉钉报警及Zabbix集成钉钉报警的示例代码
Aug 17 Python
python中的插入排序的简单用法
Jan 19 Python
python如何为list实现find方法
May 30 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
php的控制语句
2006/10/09 PHP
php数组去除空值函数分享
2015/02/02 PHP
swoole_process实现进程池的方法示例
2018/10/29 PHP
深入理解JavaScript系列(4) 立即调用的函数表达式
2012/01/15 Javascript
疯狂Jquery第一天(Jquery学习笔记)
2012/05/11 Javascript
artDialog 4.1.5 Dreamweaver代码提示/补全插件 附下载
2012/07/31 Javascript
JS+ACTIVEX实现网页选择本地目录路径对话框
2013/03/18 Javascript
jquery使用正则表达式验证email地址的方法
2015/01/22 Javascript
jQuery插件windowScroll实现单屏滚动特效
2015/07/14 Javascript
Angular Js文件上传之form-data
2015/08/28 Javascript
js实现模拟银行卡账号输入显示效果
2015/11/18 Javascript
基于jquery实现图片上传本地预览功能
2016/01/08 Javascript
Javascript highcharts 饼图显示数量和百分比实例代码
2016/12/06 Javascript
浅谈js函数三种定义方式 & 四种调用方式 & 调用顺序
2017/02/19 Javascript
使用vue-cli导入Element UI组件的方法
2018/05/16 Javascript
详解React服务端渲染从入门到精通
2019/03/28 Javascript
手把手教你 CKEDITOR 4 实现Dialog 内嵌 IFrame操作详解
2019/06/18 Javascript
python使用mysql数据库示例代码
2017/05/21 Python
Python爬虫实现获取动态gif格式搞笑图片的方法示例
2018/12/24 Python
python async with和async for的使用
2019/06/20 Python
简单了解python变量的作用域
2019/07/30 Python
使用python绘制温度变化雷达图
2019/10/18 Python
Python中的sys.stdout.write实现打印刷新功能
2020/02/21 Python
美国最受欢迎的度假租赁网站:VRBO
2016/08/02 全球购物
微软香港官网及网上商店:Microsoft HK
2016/09/01 全球购物
城野医生官方海外旗舰店:风靡亚洲毛孔收敛水
2018/04/26 全球购物
彪马土耳其官网:PUMA土耳其
2019/07/14 全球购物
德国、奥地利和瑞士最大的旅行和度假门户网站:HolidayCheck
2019/11/14 全球购物
韩国乐天网上商城:Lotte iMall
2021/02/03 全球购物
一套带网友答案的.NET笔试题
2016/12/06 面试题
单位领导证婚词
2014/01/14 职场文书
社区党总支书记先进事迹材料
2014/01/24 职场文书
数学国培研修感言
2014/02/13 职场文书
买房委托公证书
2014/04/08 职场文书
焦裕禄纪念馆观后感
2015/06/09 职场文书
Linux中sftp常用命令整理
2022/06/28 Servers