keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python sort、sorted高级排序技巧
Nov 21 Python
Python multiprocessing模块中的Pipe管道使用实例
Apr 11 Python
九步学会Python装饰器
May 09 Python
使用Python的Flask框架构建大型Web应用程序的结构示例
Jun 04 Python
Python使用cookielib模块操作cookie的实例教程
Jul 12 Python
Python实现按学生年龄排序的实际问题详解
Aug 29 Python
Python实现一个Git日志统计分析的小工具
Dec 14 Python
python读取.mat文件的数据及实例代码
Jul 12 Python
pytorch 模型可视化的例子
Aug 17 Python
Python 进程操作之进程间通过队列共享数据,队列Queue简单示例
Oct 11 Python
Python实例方法、类方法、静态方法区别详解
Sep 05 Python
python自动统计zabbix系统监控覆盖率的示例代码
Apr 03 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
PHP生成静态页
2006/11/25 PHP
将文件夹压缩成zip文件的php代码
2009/12/14 PHP
基于php伪静态的实现详细介绍
2013/04/28 PHP
PHP MVC框架中类的自动加载机制实例分析
2019/09/18 PHP
jQuery toggle()设置CSS样式
2009/11/05 Javascript
javascript 计算两个整数的百分比值
2009/12/26 Javascript
用Jquery实现可编辑表格并用AJAX提交到服务器修改数据
2009/12/27 Javascript
javascript一个无懈可击的实例化XMLHttpRequest的方法
2010/10/13 Javascript
如何让DIV可编辑、可拖动示例代码
2013/09/18 Javascript
Array 重排序方法和操作方法的简单实例
2014/01/24 Javascript
Jquery Ajax方法传值到action的方法
2014/05/11 Javascript
jquery实现html页面 div 假分页有原理有代码
2014/09/06 Javascript
javascript实现简单的贪吃蛇游戏
2015/03/31 Javascript
JS实现日期时间动态显示的方法
2015/12/07 Javascript
基于jQuery实现发送短信验证码后的倒计时功能(无视页面关闭)
2016/09/02 Javascript
浅谈js中几种实用的跨域方法原理详解
2016/12/02 Javascript
vue-cli配置文件——config篇
2018/01/04 Javascript
vue  elementUI 表单嵌套验证的实例代码
2019/11/06 Javascript
记一次用ts+vuecli4重构项目的实现
2020/05/21 Javascript
详解JavaScript中new操作符的解析和实现
2020/09/04 Javascript
前端如何实现动画过渡效果
2021/02/05 Javascript
零基础写python爬虫之爬虫的定义及URL构成
2014/11/04 Python
python实现时间o(1)的最小栈的实例代码
2018/07/23 Python
详解Python爬取并下载《电影天堂》3千多部电影
2019/04/26 Python
Python正则表达式匹配数字和小数的方法
2019/07/03 Python
numpy求平均值的维度设定的例子
2019/08/24 Python
Python requests模块基础使用方法实例及高级应用(自动登陆,抓取网页源码)实例详解
2020/02/14 Python
获取邓白氏信用报告:Dun & Bradstreet
2019/01/22 全球购物
介绍一下如何优化MySql
2016/12/20 面试题
应用化学专业职业生涯规划书
2014/01/22 职场文书
医学生就业推荐表自我鉴定
2014/03/26 职场文书
我的中国心演讲稿
2014/09/04 职场文书
群众路线四风问题整改措施
2014/09/27 职场文书
2015年扫黄打非工作总结
2015/05/13 职场文书
同学聚会开幕词
2019/04/02 职场文书
Golang之sync.Pool使用详解
2021/05/06 Golang