sklearn和keras的数据切分与交叉验证的实例详解


Posted in Python onJune 19, 2020

在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:

使用自动切分的验证集

使用手动切分的验证集

一.自动切分

在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.

具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.

# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)

validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。

二.手动切分

Keras允许在训练模型的时候手动指定验证集.

例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在keras的model.fit()的时候通过validation_data参数指定前面切分出来的验证集.

# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)

三.K折交叉验证(k-fold cross validation)

将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.

优点:能比较鲁棒性地评估模型在未知数据上的性能.

缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.

sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.

下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.

# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
 # create model
  model = Sequential()
  model.add(Dense(12, input_dim=8, activation='relu'))
  model.add(Dense(8, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))
  # Compile model
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  # Fit the model
  model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
  # evaluate the model
  scores = model.evaluate(X[test], Y[test], verbose=0)
  print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
  cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

补充知识:训练集,验证集和测试集

训练集:通过最小化目标函数(损失函数 + 正则项),用来训练模型的参数。当目标函数最小化时,完成对模型的训练。

验证集:用来选择模型的阶数。目标函数最小的模型对应的阶数,为模型的最终选择的阶数。

注:

1. 验证集会在训练过程中,反复使用,机器学习中作为选择不同模型的评判标准,深度学习中作为选择网络层数和每层节点数的评判标准。

2. 验证集的使用并非必不可少,如果网络的层数和节点数已经确定,则不需要这一步操作。

测试集:评估模型的泛化能力。根据选择的已经训练好的模型,评估它的泛化能力。

注:

测试集评判的是最终训练好的模型的泛化能力,只进行一次评判。

以上这篇sklearn和keras的数据切分与交叉验证的实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python函数学习笔记
Oct 07 Python
用Python展示动态规则法用以解决重叠子问题的示例
Apr 02 Python
Python中字典的基本知识初步介绍
May 21 Python
Python实现类似jQuery使用中的链式调用的示例
Jun 16 Python
Python实现公历(阳历)转农历(阴历)的方法示例
Aug 22 Python
python实现感知器算法(批处理)
Jan 18 Python
Python识别快递条形码及Tesseract-OCR使用详解
Jul 15 Python
python通过实例讲解反射机制
Oct 17 Python
Python 识别12306图片验证码物品的实现示例
Jan 20 Python
python数据类型可变不可变知识点总结
Mar 06 Python
python使用Thread的setDaemon启动后台线程教程
Apr 25 Python
python map比for循环快在哪
Sep 21 Python
Python虚拟环境的创建和包下载过程分析
Jun 19 #Python
通过实例解析python创建进程常用方法
Jun 19 #Python
keras model.fit 解决validation_spilt=num 的问题
Jun 19 #Python
为什么是 Python -m
Jun 19 #Python
Python 私有属性和私有方法应用场景分析
Jun 19 #Python
Python基于network模块制作电影人物关系图
Jun 19 #Python
keras中的History对象用法
Jun 19 #Python
You might like
40个迹象表明你还是PHP菜鸟
2008/09/29 PHP
php正则表达式基本知识与应用详解【经典教程】
2017/04/17 PHP
PHP模糊查询技术实例分析【附源码下载】
2019/03/07 PHP
PHP文件类型检查及fileinfo模块安装使用详解
2019/05/09 PHP
Laravel开启跨域请求的方法
2019/10/13 PHP
php + ajax 实现的写入数据库操作简单示例
2020/05/16 PHP
飞鱼(shqlsl) javascript作品集
2006/12/16 Javascript
JS之小练习代码
2008/10/12 Javascript
基于mootools 1.3框架下的图片滑动效果代码
2011/04/22 Javascript
浅析IE10兼容性问题(frameset的cols属性)
2014/01/03 Javascript
node.js中的fs.chown方法使用说明
2014/12/16 Javascript
jQuery Ajax使用实例
2015/04/16 Javascript
实例详解Nodejs 保存 payload 发送过来的文件
2016/01/14 NodeJs
详解Wondows下Node.js使用MongoDB的环境配置
2016/03/01 Javascript
第一篇初识bootstrap
2016/06/21 Javascript
JS实现给对象动态添加属性的方法
2017/01/05 Javascript
JavaScript实现的超简单计算器功能示例
2017/12/23 Javascript
webpack 4.0.0-beta.0版本新特性介绍
2018/02/10 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
js中split()方法得到的数组长度问题
2018/07/19 Javascript
js实现带箭头的进度流程
2020/03/26 Javascript
如何配置vue.config.js 处理static文件夹下的静态文件
2020/06/19 Javascript
Python多线程编程(五):死锁的形成
2015/04/05 Python
Python实现Singleton模式的方式详解
2019/08/08 Python
Django框架教程之中间件MiddleWare浅析
2019/12/29 Python
Python如何使用神经网络进行简单文本分类
2021/02/25 Python
详解CSS3的perspective属性设置3D变换距离的方法
2016/05/23 HTML / CSS
世界上最大的皮肤科医生拥有和经营的美容网站:LovelySkin
2021/01/03 全球购物
怎样从/向数据文件读/写结构
2014/11/23 面试题
大学三年的自我评价
2013/12/25 职场文书
老师的检讨书
2014/02/23 职场文书
2014两会优秀的心得体会范文
2014/03/17 职场文书
任命书模板
2014/06/04 职场文书
经典禁毒标语
2014/06/16 职场文书
Python数据分析之绘图和可视化详解
2021/06/02 Python
JavaScript 与 TypeScript之间的联系
2021/11/27 Javascript