keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django中创建URLconf相关的通用视图的方法
Jul 20 Python
python爬虫爬取淘宝商品信息
Feb 23 Python
python代码过长的换行方法
Jul 19 Python
实例介绍Python中整型
Feb 11 Python
Python 过滤错误log并导出的实例
Dec 26 Python
对Pytorch中Tensor的各种池化操作解析
Jan 03 Python
Python turtle画图库&&画姓名实例
Jan 19 Python
Python如何操作office实现自动化及win32com.client的运用
Apr 01 Python
在matplotlib中改变figure的布局和大小实例
Apr 23 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 Python
matplotlib阶梯图的实现(step())
Mar 02 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
曾在DC漫画界反派角色扮演的演员,谁才是你心目中的小丑之王?
2020/04/09 欧美动漫
Session 失效的原因汇总及解决丢失办法
2015/09/30 PHP
php实现简单的上传进度条
2015/11/17 PHP
PHP微信开发之查询微信精选文章
2016/06/23 PHP
利用laravel搭建一个迷你博客实战教程
2017/08/13 PHP
Laravel框架模板继承操作示例
2018/06/11 PHP
PHP实现redis限制单ip、单用户的访问次数功能示例
2018/06/16 PHP
Extjs4 Treegrid 使用心得分享(经验篇)
2013/07/01 Javascript
JavaScript简单实现网页回到顶部功能
2013/11/12 Javascript
js截取小数点后几位的写法
2013/11/14 Javascript
JS如何判断移动端访问设备并解析对应CSS
2013/11/27 Javascript
jquery实现弹出层完美居中效果
2014/03/03 Javascript
轻松创建nodejs服务器(2):nodejs服务器的构成分析
2014/12/18 NodeJs
JavaScript中停止执行setInterval和setTimeout事件的方法
2015/05/14 Javascript
用javascript实现自动输出网页文本
2015/07/30 Javascript
基于zepto的移动端轻量级日期插件--date_picker
2016/03/04 Javascript
JS实现把鼠标放到链接上出现滚动文字的方法
2016/04/06 Javascript
jQuery Dialog 打开时自动聚焦的解决方法(两种方法)
2016/11/24 Javascript
Extjs 中的 Treepanel 实现菜单级联选中效果及实例代码
2017/08/22 Javascript
jQuery EasyUI 折叠面板accordion的使用实例(分享)
2017/12/25 jQuery
关于axios如何全局注册浅析
2018/01/14 Javascript
小程序开发中如何使用async-await并封装公共异步请求的方法
2019/01/20 Javascript
vue中axios防止多次触发终止多次请求的示例代码(防抖)
2020/02/16 Javascript
vue 二维码长按保存和复制内容操作
2020/09/22 Javascript
vue实现简单的登录弹出框
2020/10/26 Javascript
python中使用urllib2获取http请求状态码的代码例子
2014/07/07 Python
Python内置数据类型详解
2014/08/18 Python
Python列表list数组array用法实例解析
2014/10/28 Python
Python中的Classes和Metaclasses详解
2015/04/02 Python
Python 字典与字符串的互转实例
2017/01/13 Python
python OpenCV学习笔记之绘制直方图的方法
2018/02/08 Python
python多任务之协程的使用详解
2019/08/26 Python
python实现异常信息堆栈输出到日志文件
2019/12/26 Python
Django数据统计功能count()的使用
2020/11/30 Python
matplotlib部件之套索Lasso的使用
2021/02/24 Python
2016年春季运动会加油稿
2015/07/22 职场文书