keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模块学习 datetime介绍
Aug 27 Python
Python的函数的一些高阶特性
Apr 27 Python
python中的闭包用法实例详解
May 05 Python
python dict 字典 以及 赋值 引用的一些实例(详解)
Jan 20 Python
Python针对给定列表中元素进行翻转操作的方法分析
Apr 27 Python
详解Pytorch 使用Pytorch拟合多项式(多项式回归)
May 24 Python
Python对数据进行插值和下采样的方法
Jul 03 Python
Python3随机漫步生成数据并绘制
Aug 27 Python
pytorch使用Variable实现线性回归
May 21 Python
Python利用Xpath选择器爬取京东网商品信息
Jun 01 Python
Python模拟登录和登录跳转的参考示例
Oct 30 Python
python中Array和DataFrame相互转换的实例讲解
Feb 03 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
php 定界符格式引起的错误
2011/05/24 PHP
由prototype_1.3.1进入javascript殿堂-类的初探
2006/11/06 Javascript
不同浏览器的怪癖小结
2010/07/11 Javascript
基于jQuery的Spin Button自定义文本框数值自增或自减
2010/07/17 Javascript
使用node.js 获取客户端信息代码分享
2014/11/26 Javascript
js过滤HTML标签完整实例
2015/11/26 Javascript
JavaScript实现斗地主游戏的思路
2016/02/29 Javascript
jQuery animate和CSS3相结合实现缓动追逐效果附源码下载
2016/04/18 Javascript
jQuery中数据缓存$.data的用法及源码完全解析
2016/04/29 Javascript
JavaScript String 对象常用方法详解
2016/05/13 Javascript
如何使用jquery实现文字上下滚动效果
2016/10/12 Javascript
JS 插件dropload下拉刷新、上拉加载使用小结
2017/04/13 Javascript
详解如何在vue中使用sass
2017/06/21 Javascript
Vue多组件仓库开发与发布详解
2019/02/28 Javascript
vue 点击其他区域关闭自定义div操作
2020/07/17 Javascript
JS实现炫酷轮播图
2020/11/15 Javascript
vue 动态添加的路由页面刷新时失效的原因及解决方案
2021/02/26 Vue.js
介绍Python中的fabs()方法的使用
2015/05/14 Python
Python编程之变量赋值操作实例分析
2017/07/24 Python
Python实现计算圆周率π的值到任意位的方法示例
2018/05/08 Python
Linux下python3.7.0安装教程
2018/07/30 Python
python实现文本界面网络聊天室
2018/12/12 Python
python实现蒙特卡罗方法教程
2019/01/28 Python
Python告诉你木马程序的键盘记录原理
2019/02/02 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
Python:二维列表下标互换方式(矩阵转置)
2019/12/02 Python
浅谈Python的方法解析顺序(MRO)
2020/03/05 Python
使用css3实现的tab选项卡代码分享
2014/12/09 HTML / CSS
会计专业自我鉴定范文
2013/10/06 职场文书
2014全国两会学习心得体会2000字
2014/03/10 职场文书
物联网工程专业推荐信
2014/09/08 职场文书
单位政审意见范文
2015/06/04 职场文书
小型婚礼主持词
2015/06/30 职场文书
25句企业管理语录:助你迅速打开思路,句句经典!
2020/01/14 职场文书
MySQL数据迁移相关总结
2021/04/29 MySQL
JS如何实现基于websocket的多端桥接平台
2021/05/14 Javascript