keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.6连接MySQL和表的创建与删除实例代码
Dec 28 Python
python机器学习理论与实战(四)逻辑回归
Jan 19 Python
Python微信操控itchat的方法
May 31 Python
Python简单处理坐标排序问题示例
Jul 11 Python
Pandas0.25来了千万别错过这10大好用的新功能
Aug 07 Python
python 怎样将dataframe中的字符串日期转化为日期的方法
Sep 26 Python
解决pyshp UnicodeDecodeError的问题
Dec 06 Python
基于python实现破解滑动验证码过程解析
May 28 Python
Python hashlib和hmac模块使用方法解析
Dec 08 Python
Python约瑟夫生者死者小游戏实例讲解
Jan 04 Python
matplotlib制作雷达图报错ValueError的实现
Jan 05 Python
python识别围棋定位棋盘位置
Jul 26 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
锁定年轻人的双倍活力 星巴克推出星倍醇即饮浓咖啡
2021/03/03 咖啡文化
PHP开启opcache提升代码性能
2015/04/26 PHP
PHP SOCKET编程详解
2015/05/22 PHP
四个常见html网页乱码问题及解决办法
2015/09/08 PHP
php设计模式之职责链模式实例分析【星际争霸游戏案例】
2020/03/27 PHP
greybox——不开新窗口看新的网页
2007/02/20 Javascript
Javascript Tab 导航插件 (23个)
2009/06/11 Javascript
web页面数据展示新想法(json)
2010/06/08 Javascript
浏览器解析js生成的html出现样式问题的解决方法
2012/04/16 Javascript
jquery插件制作 提示框插件实现代码
2012/08/17 Javascript
JSON.parse()和JSON.stringify()使用介绍
2014/06/20 Javascript
javascript获取dom的下一个节点方法
2014/09/05 Javascript
浅谈JavaScript数据类型及转换
2015/02/28 Javascript
使用Object.defineProperty实现简单的js双向绑定
2016/04/15 Javascript
JavaScript String 对象常用方法详解
2016/05/13 Javascript
微信小程序 动态的设置图片的高度和宽度详解及实例代码
2017/02/24 Javascript
Nodejs实现文件上传的示例代码
2017/09/26 NodeJs
微信小程序云开发(数据库)详解
2019/05/17 Javascript
Servlet返回的数据js解析2种方法
2019/12/12 Javascript
antd 表格列宽自适应方法以及错误处理操作
2020/10/27 Javascript
[02:34]肉山说——泡妞篇
2014/09/16 DOTA
Python语法快速入门指南
2015/10/12 Python
Python 调用Java实例详解
2017/06/02 Python
Django csrf 验证问题的实现
2018/10/09 Python
Django模板语言 Tags使用详解
2019/09/09 Python
python中的subprocess.Popen()使用详解
2019/12/25 Python
python图形界面开发之wxPython树控件使用方法详解
2020/02/24 Python
django跳转页面传参的实现
2020/09/17 Python
pandas 数据类型转换的实现
2020/12/29 Python
python pyg2plot的原理知识点总结
2021/02/28 Python
爱国口号
2014/06/19 职场文书
工作收入证明模板
2015/06/12 职场文书
百日宴上的祝酒词
2015/08/10 职场文书
2016公司新年问候语
2015/11/11 职场文书
解决redis sentinel 频繁主备切换的问题
2021/04/12 Redis
Android开发手册TextInputLayout样式使用示例
2022/06/10 Java/Android