keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取指定网页上所有超链接的方法
Apr 04 Python
读写json中文ASCII乱码问题的解决方法
Nov 05 Python
获取Django项目的全部url方法详解
Oct 26 Python
如何利用Python分析出微信朋友男女统计图
Jan 25 Python
Python3实现从排序数组中删除重复项算法分析
Apr 03 Python
python绘制彩虹图
Dec 16 Python
python实现跨excel sheet复制代码实例
Mar 03 Python
如何将PySpark导入Python的放实现(2种)
Apr 26 Python
使用matlab 判断两个矩阵是否相等的实例
May 11 Python
Pycharm github配置实现过程图解
Oct 13 Python
如何使用 Flask 做一个评论系统
Nov 27 Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
PHP - Html Transfer Code
2006/10/09 PHP
用php实现选择排序的解决方法
2013/05/04 PHP
php的declare控制符和ticks教程(附示例)
2014/03/21 PHP
Laravel 5 框架入门(三)
2015/04/09 PHP
PHP实现合并discuz用户
2015/08/05 PHP
POST一个JSON格式的数据给Restful服务实例详解
2017/04/07 PHP
javascript中获取选中对象的类型
2007/04/02 Javascript
网页禁用右键实现代码(JavaScript代码)
2009/10/29 Javascript
jQuery 遍历json数组的实现代码
2020/09/22 Javascript
jquery和javascript中如何将一元素的内容赋给另一元素
2014/01/09 Javascript
JS+CSS实现简易实用的滑动门菜单效果
2015/09/18 Javascript
浅谈JS中的三种字符串连接方式及其性能比较
2016/09/02 Javascript
十大 Node.js 的 Web 框架(快速提升工作效率)
2017/06/30 Javascript
浅谈angular.js跨域post解决方案
2017/08/30 Javascript
nodejs 最新版安装npm 的使用详解
2018/01/18 NodeJs
angularJs提交文本框数据到后台的方法
2018/10/08 Javascript
vue项目环境变量配置的实现方法
2018/10/12 Javascript
Layui Table js 模拟选中checkbox的例子
2019/09/03 Javascript
vue简单练习 桌面时钟的实现代码实例
2019/09/19 Javascript
浅析Python中signal包的使用
2015/11/13 Python
利用python获取某年中每个月的第一天和最后一天
2016/12/15 Python
插入排序_Python与PHP的实现版(推荐)
2017/05/11 Python
python网络爬虫之如何伪装逃过反爬虫程序的方法
2017/11/23 Python
Python实现按当前日期(年、月、日)创建多级目录的方法
2018/04/26 Python
python实现远程控制电脑
2019/05/23 Python
使用Python给头像加上圣诞帽或圣诞老人小图标附源码
2019/12/25 Python
Redbubble法国:由独立艺术家设计的独特产品
2019/01/08 全球购物
Boolean b = new Boolean(“abcde”); 会编译错误码
2013/11/27 面试题
大学生找工作推荐信范文
2013/11/28 职场文书
家长会学生家长演讲稿
2013/12/29 职场文书
《中彩那天》教学反思
2014/02/22 职场文书
酒店开业庆典主持词
2014/03/21 职场文书
春节联欢会策划方案
2014/05/16 职场文书
庆祝教师节标语
2014/10/09 职场文书
2015年党员自评材料
2014/12/17 职场文书
60条职场经典语录,总有一条能触动你的心
2019/08/21 职场文书