pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中pillow知识点学习
Apr 30 Python
Python3.4 tkinter,PIL图片转换
Jun 21 Python
python实现雨滴下落到地面效果
Jun 21 Python
python被修饰的函数消失问题解决(基于wraps函数)
Nov 04 Python
Python dict和defaultdict使用实例解析
Mar 12 Python
Python中如何引入第三方模块
May 27 Python
Python如何将模块打包并发布
Aug 30 Python
python对批量WAV音频进行等长分割的方法实现
Sep 25 Python
解决Pycharm 运行后没有输出的问题
Feb 05 Python
python中numpy.empty()函数实例讲解
Feb 05 Python
写一个Python脚本自动爬取Bilibili小视频
Apr 24 Python
用Python爬虫破解滑动验证码的案例解析
May 06 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
我的论坛源代码(一)
2006/10/09 PHP
浅析PHP 按位与或 (^ 、&)
2013/06/21 PHP
Javascript动态绑定事件的简单实现代码
2010/12/25 Javascript
jQuery EasyUI API 中文文档 - Calendar日历使用
2011/10/19 Javascript
JS中把字符转成ASCII值的函数示例代码
2013/11/21 Javascript
实用框架(iframe)操作代码
2014/10/23 Javascript
a标签的href与onclick事件的区别详解
2014/11/12 Javascript
使用HTML+CSS+JS制作简单的网页菜单界面
2015/07/27 Javascript
原生JS实现平滑回到顶部组件
2016/03/16 Javascript
原生js实现百叶窗效果及原理介绍
2016/04/12 Javascript
Jq通过td获取同行其它列td的方法
2016/10/05 Javascript
Easyui Tree获取当前选择节点的所有顶级父节点
2017/02/14 Javascript
vue实现商城购物车功能
2017/11/27 Javascript
Webpack优化配置缩小文件搜索范围
2017/12/25 Javascript
zTree 树插件实现全国五级地区点击后加载的示例
2018/02/05 Javascript
vue动态删除从数据库倒入列表的某一条方法
2018/09/29 Javascript
如何用JavaScript实现功能齐全的单链表详解
2019/02/11 Javascript
layui table去掉右侧滑动条的实现方法
2019/09/05 Javascript
JS如何实现网站中PC端和手机端自动识别并跳转对应的代码
2020/01/08 Javascript
wxPython之解决闪烁的问题
2018/01/15 Python
python web基础之加载静态文件实例
2018/03/20 Python
Python实现的逻辑回归算法示例【附测试csv文件下载】
2018/12/28 Python
PyQt5内嵌浏览器注入JavaScript脚本实现自动化操作的代码实例
2019/02/13 Python
Win10下python 2.7与python 3.7双环境安装教程图解
2019/10/12 Python
Python利用matplotlib绘制折线图的新手教程
2020/11/05 Python
会计核算科岗位职责
2014/03/19 职场文书
银行贷款收入证明
2014/10/17 职场文书
2014年安全保卫工作总结
2014/11/13 职场文书
荆州古城导游词
2015/02/06 职场文书
2015年个人审计工作总结
2015/04/07 职场文书
工会文体活动总结
2015/05/07 职场文书
酒桌上的开场白
2015/06/01 职场文书
2015团员个人年度总结
2015/11/24 职场文书
优质服务心得体会(共4篇)
2016/01/22 职场文书
HTML通过表单实现酒店筛选功能
2021/05/18 HTML / CSS
Python中的socket网络模块介绍
2022/07/23 Python