pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python绘制人人网好友关系图示例
Apr 01 Python
Python 检查数组元素是否存在类似PHP isset()方法
Oct 14 Python
简单介绍Python中的filter和lambda函数的使用
Apr 07 Python
Python命令行参数解析模块optparse使用实例
Apr 13 Python
使用Python发送邮件附件以定时备份MySQL的教程
Apr 25 Python
让python 3支持mysqldb的解决方法
Feb 14 Python
pygame实现俄罗斯方块游戏
Jun 26 Python
在Django下测试与调试REST API的方法详解
Aug 29 Python
Python如何基于rsa模块实现非对称加密与解密
Jan 03 Python
Pycharm pyuic5实现将ui文件转为py文件,让UI界面成功显示
Apr 08 Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 Python
Django框架模板用法详解
Jun 10 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
IIS下PHP连接数据库提示mysql undefined function mysql_connect()
2010/06/04 PHP
浅析PHP中的字符串编码转换(自动识别原编码)
2013/07/02 PHP
php批量删除超链接的实现方法
2015/10/19 PHP
PHP程序员的技术成长规划
2016/03/25 PHP
PHP面向对象程序设计内置标准类,普通数据类型转为对象类型示例
2019/06/12 PHP
javascript 面向对象,实现namespace,class,继承,重载
2009/10/29 Javascript
基于jquery的网站幻灯片切换效果焦点图代码
2013/09/15 Javascript
js导出table数据到excel即导出为EXCEL文档的方法
2013/10/10 Javascript
js获取电脑分辨率的思路及操作
2013/11/22 Javascript
jquery插件之定时查询待处理任务数量
2014/05/01 Javascript
一个可以增加和删除行的table并可编辑表格中内容
2014/06/16 Javascript
微信小程序 动态传参实例详解
2017/04/27 Javascript
Vue.js在使用中的一些注意知识点
2017/04/29 Javascript
JS回调函数基本定义与用法实例分析
2017/05/24 Javascript
JS实现的找零张数最小问题示例
2017/11/28 Javascript
微信小程序之分享页面如何返回首页的示例
2018/03/28 Javascript
vue如何通过id从列表页跳转到对应的详情页
2018/05/01 Javascript
Vue父子组建的简单通信之控制开关Switch的实现
2018/06/04 Javascript
angularjs $http调用接口的方式详解
2018/08/13 Javascript
vue-axios同时请求多个接口 等所有接口全部加载完成再处理操作
2020/11/09 Javascript
python threading模块操作多线程介绍
2015/04/08 Python
python实现在每个独立进程中运行一个函数的方法
2015/04/23 Python
Python 专题三 字符串的基础知识
2017/03/19 Python
Python实现文件信息进行合并实例代码
2018/01/17 Python
Python 多维List创建的问题小结
2019/01/18 Python
TensorFlow 多元函数的极值实例
2020/02/10 Python
Html5 Canvas实现图片标记、缩放、移动和保存历史状态功能 (附转换公式)
2020/03/18 HTML / CSS
史泰博(Staples)中国官方网站:办公用品一站式采购
2016/09/05 全球购物
Peter Alexander新西兰站:澳大利亚领先的睡衣设计师品牌
2016/12/10 全球购物
会计实习生工作总结的自我评价
2013/10/07 职场文书
校园自助餐厅的创业计划书
2013/12/26 职场文书
《落花生》教学反思
2014/02/25 职场文书
预备党员转正考核材料
2014/06/03 职场文书
餐饮服务员岗位职责
2015/02/09 职场文书
Python实现随机生成迷宫并自动寻路
2021/06/13 Python
MySQL提取JSON字段数据实现查询
2022/04/22 MySQL