keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单窗口倒计时界面实例
May 05 Python
python3设计模式之简单工厂模式
Oct 17 Python
python爬虫获取多页天涯帖子
Feb 23 Python
利用python为运维人员写一个监控脚本
Mar 25 Python
python 按不同维度求和,最值,均值的实例
Jun 28 Python
十行代码使用Python写一个USB病毒
Jun 21 Python
Python中新式类与经典类的区别详析
Jul 10 Python
tensorflow之自定义神经网络层实例
Feb 07 Python
Django 实现将图片转为Base64,然后使用json传输
Mar 27 Python
3种适用于Python的疯狂秘密武器及原因解析
Apr 29 Python
Python unittest单元测试框架实现参数化
Apr 29 Python
解决Django Haystack全文检索为空的问题
May 19 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
php中使用DOM类读取XML文件的实现代码
2011/12/14 PHP
PHP ignore_user_abort函数详细介绍和使用实例
2014/07/15 PHP
如何在PHP环境中使用ProtoBuf数据格式
2020/06/19 PHP
javascript模仿msgbox提示效果代码
2008/06/10 Javascript
在vs2010中调试javascript代码方法
2011/02/11 Javascript
jquery关于事件冒泡和事件委托的技巧及阻止与允许事件冒泡的三种实现方法
2015/11/27 Javascript
第五章之BootStrap 栅格系统
2016/04/25 Javascript
浅析javascript异步执行函数导致的变量变化问题解决思路
2016/05/13 Javascript
js replace(a,b)之替换字符串中所有指定字符的方法
2016/08/17 Javascript
JavaScript排序算法动画演示效果的实现方法
2016/10/18 Javascript
xmlplus组件设计系列之树(Tree)(9)
2017/05/02 Javascript
js推箱子小游戏步骤代码解析
2018/01/10 Javascript
ionic grid(栅格)九宫格制作详解
2018/06/30 Javascript
vue超时计算的组件实例代码
2018/07/09 Javascript
vue中的自定义分页插件组件的示例
2018/08/18 Javascript
Vue实现本地购物车功能
2018/12/05 Javascript
vue router 通过路由来实现切换头部标题功能
2019/04/24 Javascript
详解微信UnionID作用
2019/05/15 Javascript
微信小程序跳转到其他网页(外部链接)的实现方法
2019/09/20 Javascript
关于layui的下拉搜索框异步加载数据的解决方法
2019/09/28 Javascript
vue实现多个echarts根据屏幕大小变化而变化实例
2020/07/19 Javascript
Python中常用操作字符串的函数与方法总结
2016/02/04 Python
Python 闭包的使用方法
2017/09/07 Python
django框架面向对象ORM模型继承用法实例分析
2019/07/29 Python
django3.02模板中的超链接配置实例代码
2020/02/04 Python
深入了解Python装饰器的高级用法
2020/08/13 Python
HTML5 CSS3给网站设计带来出色效果
2009/07/16 HTML / CSS
世界上最大的乐器零售商:Guitar Center
2017/11/07 全球购物
GoPro摄像机美国官网:美国运动相机厂商
2018/07/03 全球购物
乐高瑞士官方商店:LEGO CH
2020/08/16 全球购物
什么是java序列化,如何实现java序列化
2012/11/14 面试题
党员批评与自我批评总结
2014/10/15 职场文书
扬州个园导游词
2015/02/06 职场文书
SpringBoot快速入门详解
2021/07/21 Java/Android
详解Python+OpenCV进行基础的图像操作
2022/02/15 Python
Promise静态四兄弟实现示例详解
2022/07/07 Javascript