keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作SQLite简明教程
Jul 10 Python
Python最基本的输入输出详解
Apr 25 Python
Python发送email的3种方法
Apr 28 Python
Windows上使用virtualenv搭建Python+Flask开发环境
Jun 07 Python
python魔法方法-属性转换和类的表示详解
Jul 22 Python
python如何读写csv数据
Mar 21 Python
Python实现读取txt文件并转换为excel的方法示例
May 17 Python
Python subprocess模块常见用法分析
Jun 12 Python
Django 反向生成url实例详解
Jul 30 Python
爬虫代理的cookie如何生成运行
Sep 22 Python
Python中Yield的基本用法
Oct 18 Python
python3爬虫GIL修改多线程实例讲解
Nov 24 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
解析php安全性问题中的:Null 字符问题
2013/06/21 PHP
PHP 获取文件权限函数介绍
2013/07/11 PHP
PHP随手笔记整理之PHP脚本和JAVA连接mysql数据库
2015/11/25 PHP
Ajax和PHP正则表达式验证表单及验证码
2016/09/24 PHP
javascript字符串拼接的效率问题
2010/12/25 Javascript
JavaScript字符串插入、删除、替换函数使用示例
2013/07/25 Javascript
jquery mobile changepage的三种传参方法介绍
2013/09/13 Javascript
javascript中的return和闭包函数浅析
2014/06/06 Javascript
JavaScript中使用Object.create()创建对象介绍
2014/12/30 Javascript
js计算德州扑克牌面值的方法
2015/03/04 Javascript
JavaScript通过select动态更换图片的方法
2015/03/23 Javascript
jQuery实现无限往下滚动效果代码
2016/04/16 Javascript
Node.js学习之TCP/IP数据通讯(实例讲解)
2017/10/11 Javascript
bootstrap Table的一些小操作
2017/11/01 Javascript
vue组件数据传递、父子组件数据获取,slot,router路由功能示例
2019/03/19 Javascript
python 字典(dict)遍历的四种方法性能测试报告
2014/06/25 Python
在Python下使用Txt2Html实现网页过滤代理的教程
2015/04/11 Python
pygame学习笔记(2):画点的三种方法和动画实例
2015/04/15 Python
python实现可将字符转换成大写的tcp服务器实例
2015/04/29 Python
Python安装Numpy和matplotlib的方法(推荐)
2017/11/02 Python
Python3.7中安装openCV库的方法
2018/07/11 Python
python爬取网易云音乐评论
2018/11/16 Python
Python IDE Pycharm中的快捷键列表用法
2019/08/08 Python
浅谈如何使用python抓取网页中的动态数据实现
2020/08/17 Python
标签和贴纸印刷:Lightning Labels
2018/03/22 全球购物
餐饮业的创业计划书范文
2013/12/26 职场文书
大学毕业生个人自荐信范文
2014/01/08 职场文书
国家机关领导干部民主生活会对照检查材料思想汇报
2014/09/17 职场文书
一般党员对照检查材料
2014/09/24 职场文书
秦兵马俑导游词
2015/02/02 职场文书
观后感格式
2015/06/19 职场文书
通讯稿格式及范文
2015/07/22 职场文书
大学文艺委员竞选稿
2015/11/19 职场文书
2019幼儿园感恩节活动策划书
2019/11/28 职场文书
分享提高 Python 代码的可读性的技巧
2022/03/03 Python
Mysql中的触发器定义及语法介绍
2022/06/25 MySQL