keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用cookielib库示例分享
Mar 03 Python
python元组操作实例解析
Sep 23 Python
各种Python库安装包下载地址与安装过程详细介绍(Windows版)
Nov 02 Python
使用Python+Splinter自动刷新抢12306火车票
Jan 03 Python
Python实现string字符串连接的方法总结【8种方式】
Jul 06 Python
对IPython交互模式下的退出方法详解
Feb 16 Python
对python3 Serial 串口助手的接收读取数据方法详解
Jun 12 Python
pyqt5 lineEdit设置密码隐藏,删除lineEdit已输入的内容等属性方法
Jun 24 Python
Spring Cloud Feign高级应用实例详解
Dec 10 Python
python的Jenkins接口调用方式
May 12 Python
python 基于opencv操作摄像头
Dec 24 Python
2021年值得向Python开发者推荐的VS Code扩展插件
Jan 25 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
thinkphp ajaxfileupload实现异步上传图片的示例
2017/08/28 PHP
javascript 动态添加事件代码
2008/11/30 Javascript
分享XmlHttpRequest调用Webservice的一点心得
2012/07/20 Javascript
js 获取(接收)地址栏参数值的方法
2013/04/01 Javascript
浅谈node.js中async异步编程
2015/10/22 Javascript
mac下的nodejs环境安装的步骤
2017/05/24 NodeJs
vue2.0 computed 计算list循环后累加值的实例
2018/03/07 Javascript
详解JS取出两个数组中的不同或相同元素
2019/03/20 Javascript
运用js实现图层拖拽的功能
2019/05/24 Javascript
javascript实现抢购倒计时程序
2019/08/26 Javascript
解决layui 表单元素radio不显示渲染的问题
2019/09/04 Javascript
layui type2 通过url给iframe子页面传值的例子
2019/09/06 Javascript
JS实现TITLE悬停长久显示效果完整示例
2020/02/11 Javascript
jquery实现抽奖功能
2020/10/22 jQuery
小程序实现上下切换位置
2020/11/16 Javascript
详解JavaScript原型与原型链
2020/11/16 Javascript
Vue实现todo应用的示例
2021/02/20 Vue.js
[32:56]完美世界DOTA2联赛PWL S3 Rebirth vs CPG 第二场 12.11
2020/12/16 DOTA
python django 实现验证码的功能实例代码
2017/05/18 Python
Python3中使用PyMongo的方法详解
2017/07/28 Python
使用python为mysql实现restful接口
2018/01/05 Python
使用numpy和PIL进行简单的图像处理方法
2018/07/02 Python
总结python中pass的作用
2019/02/27 Python
详解python websocket获取实时数据的几种常见链接方式
2019/07/01 Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
2020/01/04 Python
Python matplotlib模块及柱状图用法解析
2020/08/10 Python
用HTML5制作烟火效果的教程
2015/05/12 HTML / CSS
美的官方商城:Midea
2016/09/14 全球购物
美国婚礼装饰和活动用品批发供应商:Event Decor Direct
2018/10/12 全球购物
英国排名第一的停车场运营商:NCP
2019/08/26 全球购物
欧洲最大的预定车位市场:JustPark
2020/01/06 全球购物
教师绩效考核方案
2014/01/21 职场文书
经济信息系毕业生自荐信
2014/06/02 职场文书
街道社区活动报告
2015/02/05 职场文书
老公婚前保证书
2015/02/28 职场文书
输入框跟随文字内容适配宽实现示例
2022/08/14 Javascript