sklearn和keras的数据切分与交叉验证的实例详解


Posted in Python onJune 19, 2020

在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:

使用自动切分的验证集

使用手动切分的验证集

一.自动切分

在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.

具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.

# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)

validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。

二.手动切分

Keras允许在训练模型的时候手动指定验证集.

例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在keras的model.fit()的时候通过validation_data参数指定前面切分出来的验证集.

# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)

三.K折交叉验证(k-fold cross validation)

将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.

优点:能比较鲁棒性地评估模型在未知数据上的性能.

缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.

sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.

下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.

# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
 # create model
  model = Sequential()
  model.add(Dense(12, input_dim=8, activation='relu'))
  model.add(Dense(8, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))
  # Compile model
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  # Fit the model
  model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
  # evaluate the model
  scores = model.evaluate(X[test], Y[test], verbose=0)
  print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
  cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

补充知识:训练集,验证集和测试集

训练集:通过最小化目标函数(损失函数 + 正则项),用来训练模型的参数。当目标函数最小化时,完成对模型的训练。

验证集:用来选择模型的阶数。目标函数最小的模型对应的阶数,为模型的最终选择的阶数。

注:

1. 验证集会在训练过程中,反复使用,机器学习中作为选择不同模型的评判标准,深度学习中作为选择网络层数和每层节点数的评判标准。

2. 验证集的使用并非必不可少,如果网络的层数和节点数已经确定,则不需要这一步操作。

测试集:评估模型的泛化能力。根据选择的已经训练好的模型,评估它的泛化能力。

注:

测试集评判的是最终训练好的模型的泛化能力,只进行一次评判。

以上这篇sklearn和keras的数据切分与交叉验证的实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python备份文件以及mysql数据库的脚本代码
Jun 10 Python
python的dict,set,list,tuple应用详解
Jul 24 Python
Python使用Flask框架获取当前查询参数的方法
Mar 21 Python
Windows下用py2exe将Python程序打包成exe程序的教程
Apr 08 Python
在Django的模板中使用认证数据的方法
Jul 23 Python
python批量添加zabbix Screens的两个脚本分享
Jan 16 Python
python类的方法属性与方法属性的动态绑定代码详解
Dec 27 Python
python和pygame实现简单俄罗斯方块游戏
Feb 19 Python
一文带你了解Python中的字符串是什么
Nov 20 Python
Python3 修改默认环境的方法
Feb 16 Python
TensorFlow2.0矩阵与向量的加减乘实例
Feb 07 Python
python 进阶学习之python装饰器小结
Sep 04 Python
Python虚拟环境的创建和包下载过程分析
Jun 19 #Python
通过实例解析python创建进程常用方法
Jun 19 #Python
keras model.fit 解决validation_spilt=num 的问题
Jun 19 #Python
为什么是 Python -m
Jun 19 #Python
Python 私有属性和私有方法应用场景分析
Jun 19 #Python
Python基于network模块制作电影人物关系图
Jun 19 #Python
keras中的History对象用法
Jun 19 #Python
You might like
德生PL450的电路分析和低放电路的改进办法
2021/03/02 无线电
PHP 可阅读随机字符串代码
2010/05/26 PHP
浅析php变量修饰符static的使用
2013/06/28 PHP
showModalDialog 和 showModelessDialog
2007/01/22 Javascript
jquery控制listbox中项的移动并排序的实现代码
2010/09/28 Javascript
统计jQuery中各字符串出现次数的工具
2012/05/03 Javascript
jQuery实现转动随机数抽奖效果的方法
2015/05/21 Javascript
JavaScript中使用Math.floor()方法对数字取整
2015/06/15 Javascript
纯javascript代码实现计算器功能(三种方法)
2015/09/07 Javascript
JavaScript浮点数及运算精度调整详解
2016/10/21 Javascript
详谈Angular路由与Nodejs路由的区别
2017/03/05 NodeJs
socket.io学习教程之基本应用(二)
2017/04/29 Javascript
Angular2 之 路由与导航详细介绍
2017/05/26 Javascript
详解AngularJS ng-class样式切换
2017/06/27 Javascript
如何理解Vue的作用域插槽的实现原理
2017/08/19 Javascript
vue-resource请求实现http登录拦截或者路由拦截的方法
2018/07/11 Javascript
利用webpack理解CommonJS和ES Modules的差异区别
2020/06/16 Javascript
详解Vue中的MVVM原理和实现方法
2020/07/15 Javascript
详解vue中使用transition和animation的实例代码
2020/12/12 Vue.js
[05:05]DOTA2亚洲邀请赛 战队出场仪式
2015/02/07 DOTA
[01:08:09]DOTA2上海特级锦标赛主赛事日 - 1 胜者组第一轮#1Liquid VS Alliance第二局
2016/03/02 DOTA
浅谈终端直接执行py文件,不需要python命令
2017/01/23 Python
Python中摘要算法MD5,SHA1简介及应用实例代码
2018/01/09 Python
Python中循环引用(import)失败的解决方法
2018/04/22 Python
python线程中同步锁详解
2018/04/27 Python
python 字典修改键(key)的几种方法
2018/08/10 Python
Django框架模板语言实例小结【变量,标签,过滤器,继承,html转义】
2019/05/23 Python
keras获得model中某一层的某一个Tensor的输出维度教程
2020/01/24 Python
python求前n个阶乘的和实例
2020/04/02 Python
HTML5 直播疯狂点赞动画实现代码 附源码
2020/04/14 HTML / CSS
Move Free官方海外旗舰店:美国骨关节健康专业品牌
2017/12/06 全球购物
塑料制成的可水洗的编织平底鞋和鞋子:Rothy’s
2018/09/16 全球购物
实习单位证明范例
2014/11/17 职场文书
如何写辞职信
2015/05/13 职场文书
浙江省杭州市平均工资标准是多少?
2019/07/09 职场文书
Python中的matplotlib绘制百分比堆叠柱状图,并为每一个类别设置不同的填充图案
2022/04/20 Python