keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用urllib示例取googletranslate(谷歌翻译)
Jan 23 Python
Python中用max()方法求最大值的介绍
May 15 Python
在Django中使用Sitemap的方法讲解
Jul 22 Python
python 获取文件下所有文件或目录os.walk()的实例
Apr 23 Python
通过python顺序修改文件名字的方法
Jul 11 Python
python在TXT文件中按照某一字符串取出该字符串所在的行方法
Dec 10 Python
python爬取微信公众号文章的方法
Feb 26 Python
Python3 chardet模块查看编码格式的例子
Aug 14 Python
Python使用matplotlib 画矩形的三种方式分析
Oct 31 Python
Python run()函数和start()函数的比较和差别介绍
May 03 Python
python爬虫调度器用法及实例代码
Nov 30 Python
Python使用pandas导入csv文件内容的示例代码
Dec 24 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
PHP中一个控制字符串输出的函数
2006/10/09 PHP
使用session判断用户登录用户权限(超简单)
2013/06/08 PHP
thinkPHP下ueditor的使用方法详解
2015/12/26 PHP
提高网站信任度的技巧
2008/10/17 Javascript
Uglifyjs(JS代码优化工具)入门 安装使用
2020/04/13 Javascript
jquery.qrcode在线生成二维码使用示例
2013/08/21 Javascript
js全屏显示显示代码的三种方法
2013/11/11 Javascript
JavaScript按位运算符的应用简析
2014/02/04 Javascript
js时钟翻牌效果实现代码分享
2020/07/31 Javascript
jquery实现动画菜单的左右滚动、渐变及图形背景滚动等效果
2015/08/25 Javascript
深入探讨javascript函数式编程
2015/10/11 Javascript
jQuery实现表格行和列的动态添加与删除方法【测试可用】
2016/08/01 Javascript
jQuery实现的自动加载页面功能示例
2016/09/04 Javascript
Vue 进阶教程之v-model详解
2017/05/06 Javascript
vue2.0 自定义日期时间过滤器
2017/06/07 Javascript
详解webpack和webpack-simple中如何引入css文件
2017/06/28 Javascript
AJAX在JQuery中的应用详解
2019/01/30 jQuery
浅谈Express.js解析Post数据类型的正确姿势
2019/05/30 Javascript
python实现DNS正向查询、反向查询的例子
2014/04/25 Python
MySQL中表的复制以及大型数据表的备份教程
2015/11/25 Python
对python中的for循环和range内置函数详解
2018/04/17 Python
详谈python3 numpy-loadtxt的编码问题
2018/04/29 Python
浅谈pycharm导入pandas包遇到的问题及解决
2020/06/01 Python
KLOOK客路:发现更好玩的世界,预订独一无二的旅行体验
2016/12/16 全球购物
Otel.com:折扣酒店预订
2017/08/24 全球购物
通往英国高街的商店橱窗:Down Your High Street
2020/07/19 全球购物
Linux如何命名文件--使用文件名时应注意
2012/01/22 面试题
老师推荐信
2013/10/28 职场文书
共产党员承诺书
2014/03/25 职场文书
爱心倡议书范文
2014/05/12 职场文书
党员承诺书范文
2014/05/19 职场文书
人事任命书怎么写
2014/06/05 职场文书
学校党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
2015初中教导处工作总结
2015/07/21 职场文书
创业计划书之川味火锅店
2019/09/02 职场文书
oracle delete误删除表数据后如何恢复
2022/06/28 Oracle