keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
高质量Python代码编写的5个优化技巧
Nov 16 Python
Python格式化输出%s和%d
May 07 Python
python使用turtle库与random库绘制雪花
Jun 22 Python
python设置环境变量的原因和方法
Jun 24 Python
用Python从0开始实现一个中文拼音输入法的思路详解
Jul 20 Python
Python 调用 Windows API COM 新法
Aug 22 Python
python3反转字符串的3种方法(小结)
Nov 07 Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 Python
Python中私有属性的定义方式
Mar 05 Python
用python绘制樱花树
Oct 09 Python
Python3中的tuple函数知识点讲解
Jan 03 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
PHP为表单获取的URL 地址预设 http 字符串函数代码
2010/05/26 PHP
php写的带缓存数据功能的mysqli类
2012/09/06 PHP
在PHP模板引擎smarty生成随机数的方法和math函数详解
2014/04/24 PHP
Windows2003下php5.4安装配置教程(IIS)
2016/06/30 PHP
PHP有序表查找之插值查找算法示例
2018/02/10 PHP
PHP基于面向对象封装的分页类示例
2019/03/15 PHP
基于PHP+mysql实现新闻发布系统的开发
2020/08/06 PHP
Jquery 设置标题的自动翻转
2009/10/03 Javascript
Javascript 鼠标移动上去 滑块跟随效果代码分享
2013/11/23 Javascript
jQuery .tmpl() 用法示例介绍
2014/08/21 Javascript
jQuery焦点图轮播插件KinSlideshow用法分析
2016/06/08 Javascript
概述一个页面从输入URL到页面加载完的过程
2016/12/16 Javascript
完美解决JS文件页面加载时的阻塞问题
2016/12/18 Javascript
原生ajax处理json格式数据的实例代码
2016/12/25 Javascript
jquery实现图片轮播器
2017/05/23 jQuery
解决vue-quill-editor上传内容由于图片是base64的导致字符太长的问题
2018/08/20 Javascript
js中实例与对象的区别讲解
2019/01/21 Javascript
详解vue移动端项目代码拆分记录
2019/03/15 Javascript
JS表格的动态操作完整示例
2020/01/13 Javascript
[03:54]DOTA2英雄梦之声_第06期_昆卡
2014/06/23 DOTA
python字符串替换示例
2014/04/24 Python
Python制作简易注册登录系统
2016/12/15 Python
全面了解Nginx, WSGI, Flask之间的关系
2018/01/09 Python
python将字符串转变成dict格式的实现
2019/11/18 Python
使用Keras加载含有自定义层或函数的模型操作
2020/06/10 Python
HTML中fieldset标签概述及使用方法
2013/02/01 HTML / CSS
HTML5 预加载让页面得以快速呈现
2013/08/13 HTML / CSS
日本酒店、民宿、温泉旅馆、当地旅行团中文预订:e路东瀛
2019/12/09 全球购物
如何提高JDBC的性能
2013/04/30 面试题
什么造成了Java里面的异常
2016/04/24 面试题
建筑设计师岗位职责
2013/11/18 职场文书
旅游管理专业大学生职业规划书
2014/02/27 职场文书
孝敬父母的活动方案
2014/08/28 职场文书
2014年学生会工作总结范文
2014/11/07 职场文书
民事诉讼代理词
2015/05/25 职场文书
参加招聘会后的感想
2015/08/10 职场文书