pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python输出汉字字库及将文字转换为图片的方法
Jun 04 Python
Python解决走迷宫问题算法示例
Jul 27 Python
Python装饰器用法实例分析
Jan 14 Python
python实现QQ空间自动点赞功能
Apr 09 Python
python读文件的步骤
Oct 08 Python
Pycharm和Idea支持的vim插件的方法
Feb 21 Python
浅谈python中频繁的print到底能浪费多长时间
Feb 21 Python
python实现逆滤波与维纳滤波示例
Feb 26 Python
python GUI库图形界面开发之PyQt5单选按钮控件QRadioButton详细使用方法与实例
Feb 28 Python
浅谈pycharm导入pandas包遇到的问题及解决
Jun 01 Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 Python
Python collections模块的使用方法
Oct 09 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
PHP 危险函数解释 分析
2009/04/22 PHP
php获取qq用户昵称和在线状态(实例分析)
2013/10/27 PHP
PHP中4个加速、缓存扩展的区别和选用建议
2014/03/12 PHP
php通过asort()给关联数组按照值排序的方法
2015/03/18 PHP
PHP使用curl制作简易百度搜索
2016/11/03 PHP
PHP实现的curl批量请求操作示例
2018/06/06 PHP
浅谈PHP之ThinkPHP框架使用详解
2020/07/21 PHP
javascript与CSS复习(三)
2010/06/29 Javascript
javascript之querySelector和querySelectorAll使用说明
2011/10/09 Javascript
JS获取鼠标坐标的实例方法
2013/07/18 Javascript
实现51Map地图接口(示例代码)
2013/11/22 Javascript
Javascript中数组方法汇总(推荐)
2015/04/01 Javascript
聊一聊JavaScript作用域和作用域链
2016/05/03 Javascript
jquery设置css样式的多种方法(总结)
2017/02/21 Javascript
使用jQuery实现一个类似GridView的编辑,更新,取消和删除的功能
2017/03/15 Javascript
微信小程序删除处理详解
2017/08/16 Javascript
react高阶组件经典应用之权限控制详解
2017/09/07 Javascript
基于Vue 2.0 监听文本框内容变化及ref的使用说明介绍
2018/08/24 Javascript
angularjs1.5 组件内用函数向外传值的实例
2018/09/30 Javascript
微信小程序如何使用globalData的方法
2019/06/06 Javascript
了解Javascript中函数作为对象的魅力
2019/06/19 Javascript
在weex中愉快的使用scss的方法步骤
2020/01/02 Javascript
详解Vue.js3.0 组件是如何渲染为DOM的
2020/11/10 Javascript
python中os操作文件及文件路径实例汇总
2015/01/15 Python
python使用BeautifulSoup分析网页信息的方法
2015/04/04 Python
深入理解Python中命名空间的查找规则LEGB
2015/08/06 Python
谈谈Python中的while循环语句
2019/03/10 Python
美国滑雪板和装备购物网站:Skis.com
2018/12/20 全球购物
为什么使用接口?
2014/08/13 面试题
期终自我鉴定
2014/02/17 职场文书
会计专业个人自我鉴定
2014/03/21 职场文书
工商管理本科生求职信
2014/07/13 职场文书
2014年教师节讲话稿5篇
2014/09/10 职场文书
2014年大学学生会工作总结
2014/12/02 职场文书
微信小程序 根据不同用户切换不同TabBar
2022/04/21 Javascript
windows系统搭建WEB服务器详细教程
2022/08/05 Servers