pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于BeautifulSoup实现抓取网页指定内容的方法
Jul 09 Python
利用scrapy将爬到的数据保存到mysql(防止重复)
Mar 31 Python
python中字符串变二维数组的实例讲解
Apr 03 Python
python 用lambda函数替换for循环的方法
Jun 09 Python
详解Python3的TFTP文件传输
Jun 26 Python
python爬虫爬取微博评论案例详解
Mar 27 Python
PyCharm+Qt Designer+PyUIC安装配置教程详解
Jun 13 Python
python实现智能语音天气预报
Dec 02 Python
python二维键值数组生成转json的例子
Dec 06 Python
Python脚本导出为exe程序的方法
Mar 25 Python
Django {{ MEDIA_URL }}无法显示图片的解决方式
Apr 07 Python
Python编写冷笑话生成器
Apr 20 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
Php中用PDO查询Mysql来避免SQL注入风险的方法
2013/04/25 PHP
PHP字符串的连接的简单实例
2013/12/30 PHP
phpmyadmin配置文件现在需要绝密的短密码(blowfish_secret)的2种解决方法
2014/05/07 PHP
php项目中百度 UEditor 简单安装调试和调用
2015/07/15 PHP
php关闭warning问题的解决方法
2016/05/17 PHP
在页面上点击任一链接时触发一个事件的代码
2007/04/07 Javascript
添加JavaScript重载函数的辅助方法2
2010/07/04 Javascript
Asp.net下利用Jquery Ajax实现用户注册检测(验证用户名是否存)
2010/09/12 Javascript
onclick与listeners的执行先后问题详细解剖
2013/01/07 Javascript
浅析JavaScript原型继承的陷阱
2013/12/03 Javascript
Jquery自定义button按钮的几种方法
2014/06/11 Javascript
javascript转换日期字符串为Date日期对象的方法
2015/02/13 Javascript
RequireJS简易绘图程序开发
2016/10/28 Javascript
在node中如何使用 ES6
2017/04/22 Javascript
Webpack性能优化 DLL 用法详解
2017/08/10 Javascript
Mac 安装 nodejs方法(图文详细步骤)
2017/10/30 NodeJs
解决element-ui中下拉菜单子选项click事件不触发的问题
2018/08/22 Javascript
全面了解JavaScript的作用域链
2019/04/03 Javascript
Vue监听滚动实现锚点定位(双向)示例
2019/11/13 Javascript
node.js开发辅助工具nodemon安装与配置详解
2020/02/06 Javascript
Vue3不支持Filters过滤器的问题
2020/09/24 Javascript
Linux中安装Python的交互式解释器IPython的教程
2016/06/13 Python
用Pygal绘制直方图代码示例
2017/12/07 Python
基于Python实现的微信好友数据分析
2018/02/26 Python
Python整数对象实现原理详解
2019/07/01 Python
解决安装pyqt5之后无法打开spyder的问题
2019/12/13 Python
python读取ini配置文件过程示范
2019/12/23 Python
详解Selenium 元素定位和WebDriver常用方法
2020/12/04 Python
20行代码教你用python给证件照换底色的方法示例
2021/02/05 Python
CSS3 @font-face属性使用指南
2014/12/12 HTML / CSS
HTML5中Localstorage的使用教程
2015/07/09 HTML / CSS
DC Shoes俄罗斯官网:美国滑板鞋和服饰品牌
2020/08/19 全球购物
远程教育心得体会
2014/01/03 职场文书
护士求职信
2014/07/05 职场文书
授权委托书(法人单位用)
2014/09/29 职场文书
幼师自荐信范文(2016推荐篇)
2016/01/28 职场文书