keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python循环语句中else的用法总结
Sep 11 Python
Python连接phoenix的方法示例
Sep 29 Python
pandas数据处理基础之筛选指定行或者指定列的数据
May 03 Python
使用Python3+PyQT5+Pyserial 实现简单的串口工具方法
Feb 13 Python
python使用adbapi实现MySQL数据库的异步存储
Mar 19 Python
java中的控制结构(if,循环)详解
Jun 26 Python
Django Rest framework认证组件详细用法
Jul 25 Python
使用Django搭建web服务器的例子(最最正确的方式)
Aug 29 Python
python GUI库图形界面开发之PyQt5 UI主线程与耗时线程分离详细方法实例
Feb 26 Python
Python生成器next方法和send方法区别详解
May 30 Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 Python
python实现简单文件读写函数
Feb 25 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
使用php shell命令合并图片的代码
2011/06/23 PHP
ThinkPHP写数组插入与获取最新插入数据ID实例
2014/11/03 PHP
PHP通过反射动态加载第三方类和获得类源码的实例
2015/11/27 PHP
功能强大的php分页函数
2016/07/20 PHP
py文件转exe时包含paramiko模块出错解决方法
2016/08/12 PHP
php is_executable判断给定文件名是否可执行实例
2016/09/26 PHP
全面解析PHP面向对象的三大特征
2017/06/10 PHP
改善用户体验的五款jQuery插件分享
2011/05/22 Javascript
页面右下角弹出提示框示例代码js版
2013/08/02 Javascript
在JavaScript中实现类的方式探讨
2013/08/28 Javascript
js拖拽一些常见的思路方法整理
2014/03/19 Javascript
jquery实现表格本地排序的方法
2015/03/11 Javascript
Javascript将数字转化成为货币格式字符串
2016/06/22 Javascript
JS实现的随机排序功能算法示例
2017/06/09 Javascript
纯html+css+javascript实现楼层跳跃式的页面布局(实例代码)
2017/10/25 Javascript
Angular如何在应用初始化时运行代码详解
2018/06/11 Javascript
对Vue2 自定义全局指令Vue.directive和指令的生命周期介绍
2018/08/30 Javascript
微信小程序事件流原理解析
2019/11/27 Javascript
JavaScript将数组转换为链表的方法
2020/02/16 Javascript
详解node和ES6的模块导出与导入
2020/02/19 Javascript
vue-cli3中配置alias和打包加hash值操作
2020/09/04 Javascript
简单理解Python中的装饰器
2015/07/31 Python
python运行时间的几种方法
2016/06/17 Python
Django实现的自定义访问日志模块示例
2017/06/23 Python
Python Pandas 转换unix时间戳方式
2019/12/07 Python
Python 实现黑客帝国中的字符雨的示例代码
2020/02/20 Python
世界上最大的二手相机店:KEN
2017/05/17 全球购物
艺术专业大学生自我评价
2013/09/22 职场文书
师范生实习的个人自我鉴定
2013/10/20 职场文书
创意婚礼策划方案
2014/05/18 职场文书
总经理岗位职责
2015/02/04 职场文书
班主任远程培训研修日志
2015/11/13 职场文书
关于感恩的素材句子(38句)
2019/11/11 职场文书
react中props 的使用及进行限制的方法
2021/04/28 Javascript
MySQL update set 和 and的区别
2021/05/08 MySQL
Python集合set()使用的方法详解
2022/03/18 Python