keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
在Python 字典中一键对应多个值的实例
Feb 03 Python
python实现nao机器人手臂动作控制
Apr 29 Python
Python 通过微信控制实现app定位发送到个人服务器再转发微信服务器接收位置信息
Aug 05 Python
关于Keras模型可视化教程及关键问题的解决
Jan 24 Python
Python pip配置国内源的方法
Feb 14 Python
keras的backend 设置 tensorflow,theano操作
Jun 30 Python
python使用re模块爬取豆瓣Top250电影
Oct 20 Python
如何通过安装HomeBrew来安装Python3
Dec 23 Python
python中用ggplot绘制画图实例讲解
Jan 26 Python
Python plt 利用subplot 实现在一张画布同时画多张图
Feb 26 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
May 28 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
php设计模式 Observer(观察者模式)
2011/06/26 PHP
php删除文件夹及其文件夹下所有文件的函数代码
2013/01/23 PHP
php 与 nginx 的处理方式及nginx与php-fpm通信的两种方式
2018/09/28 PHP
html数组字符串拼接的最快方法
2009/09/16 Javascript
jquery form表单提交插件asp.net后台中文解码
2010/06/12 Javascript
Extjs中通过Tree加载右侧TabPanel具体实现
2013/05/05 Javascript
jQuery实现鼠标移到元素上动态提示消息框效果
2013/10/20 Javascript
jQuery插件jFade实现鼠标经过的图片高亮其它变暗
2015/03/14 Javascript
JS实现表单验证功能(验证手机号是否存在,验证码倒计时)
2016/10/11 Javascript
bootstrap实现的自适应页面简单应用示例
2017/03/09 Javascript
js学习心得_一个简单的动画库封装tween.js
2017/07/14 Javascript
解决option标签selected="selected"属性失效的问题
2017/11/06 Javascript
nodejs爬虫初试superagent和cheerio
2018/03/05 NodeJs
JS实现点击按钮随机生成可拖动的不同颜色块示例
2019/01/30 Javascript
JavaScript中AOP的实现与应用
2019/05/06 Javascript
微信小程序实现消息框弹出动画
2020/04/18 Javascript
微信小程序开发常见问题及解决方案
2019/07/11 Javascript
微信小程序全局变量的设置、使用、修改过程解析
2019/09/24 Javascript
[52:15]2014 DOTA2国际邀请赛中国区预选赛5.21 HGT VS LGD-GAMING
2014/05/23 DOTA
Python运行的17个时新手常见错误小结
2012/08/07 Python
Python运维自动化之nginx配置文件对比操作示例
2018/08/29 Python
在python 中实现运行多条shell命令
2019/01/07 Python
解决Django中多条件查询的问题
2019/07/18 Python
python多线程实现同时执行两个while循环的操作
2020/05/02 Python
html5 input属性使用示例
2013/06/28 HTML / CSS
HTML5新增的表单元素和属性实例解析
2014/07/07 HTML / CSS
美国最好的钓鱼、狩猎和划船装备商店:Bass Pro Shops
2018/12/02 全球购物
Notino希腊:购买香水和美容产品
2019/07/25 全球购物
自考毕业生自我鉴定
2013/11/04 职场文书
中学门卫岗位职责
2013/12/26 职场文书
临时租车协议范本
2014/09/23 职场文书
公司2014年度工作总结
2014/12/10 职场文书
工作简报怎么写
2015/07/21 职场文书
2016年12月份红领巾广播稿
2015/12/21 职场文书
教你利用Nginx 服务搭建子域环境提升二维地图加载性能的步骤
2021/09/25 Servers
docker compose 部署 golang 的 Athens 私有代理问题
2022/04/28 Servers