详解tensorflow之过拟合问题实战


Posted in Python onNovember 01, 2020

过拟合问题实战

1.构建数据集

我们使用的数据集样本特性向量长度为 2,标签为 0 或 1,分别代表了 2 种类别。借助于 scikit-learn 库中提供的 make_moons 工具我们可以生成任意多数据的训练集。

import matplotlib.pyplot as plt
# 导入数据集生成工具
import numpy as np
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, Sequential, regularizers
from mpl_toolkits.mplot3d import Axes3D

为了演示过拟合现象,我们只采样了 1000 个样本数据,同时添加标准差为 0.25 的高斯噪声数据:

def load_dataset():
 # 采样点数
 N_SAMPLES = 1000
 # 测试数量比率
 TEST_SIZE = None

 # 从 moon 分布中随机采样 1000 个点,并切分为训练集-测试集
 X, y = make_moons(n_samples=N_SAMPLES, noise=0.25, random_state=100)
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
 return X, y, X_train, X_test, y_train, y_test

make_plot 函数可以方便地根据样本的坐标 X 和样本的标签 y 绘制出数据的分布图:

def make_plot(X, y, plot_name, file_name, XX=None, YY=None, preds=None, dark=False, output_dir=OUTPUT_DIR):
 # 绘制数据集的分布, X 为 2D 坐标, y 为数据点的标签
 if dark:
  plt.style.use('dark_background')
 else:
  sns.set_style("whitegrid")
 axes = plt.gca()
 axes.set_xlim([-2, 3])
 axes.set_ylim([-1.5, 2])
 axes.set(xlabel="$x_1$", ylabel="$x_2$")
 plt.title(plot_name, fontsize=20, fontproperties='SimHei')
 plt.subplots_adjust(left=0.20)
 plt.subplots_adjust(right=0.80)
 if XX is not None and YY is not None and preds is not None:
  plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=0.08, cmap=plt.cm.Spectral)
  plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6)
 # 绘制散点图,根据标签区分颜色m=markers
 markers = ['o' if i == 1 else 's' for i in y.ravel()]
 mscatter(X[:, 0], X[:, 1], c=y.ravel(), s=20, cmap=plt.cm.Spectral, edgecolors='none', m=markers, ax=axes)
 # 保存矢量图
 plt.savefig(output_dir + '/' + file_name)
 plt.close()
def mscatter(x, y, ax=None, m=None, **kw):
 import matplotlib.markers as mmarkers
 if not ax: ax = plt.gca()
 sc = ax.scatter(x, y, **kw)
 if (m is not None) and (len(m) == len(x)):
  paths = []
  for marker in m:
   if isinstance(marker, mmarkers.MarkerStyle):
    marker_obj = marker
   else:
    marker_obj = mmarkers.MarkerStyle(marker)
   path = marker_obj.get_path().transformed(
    marker_obj.get_transform())
   paths.append(path)
  sc.set_paths(paths)
 return sc
X, y, X_train, X_test, y_train, y_test = load_dataset()
make_plot(X,y,"haha",'月牙形状二分类数据集分布.svg')

详解tensorflow之过拟合问题实战

2.网络层数的影响

为了探讨不同的网络深度下的过拟合程度,我们共进行了 5 次训练实验。在? ∈ [0,4]时,构建网络层数为n + 2层的全连接层网络,并通过 Adam 优化器训练 500 个 Epoch

def network_layers_influence(X_train, y_train):
 # 构建 5 种不同层数的网络
 for n in range(5):
  # 创建容器
  model = Sequential()
  # 创建第一层
  model.add(layers.Dense(8, input_dim=2, activation='relu'))
  # 添加 n 层,共 n+2 层
  for _ in range(n):
   model.add(layers.Dense(32, activation='relu'))
  # 创建最末层
  model.add(layers.Dense(1, activation='sigmoid'))
  # 模型装配与训练
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  model.fit(X_train, y_train, epochs=N_EPOCHS, verbose=1)
  # 绘制不同层数的网络决策边界曲线
  # 可视化的 x 坐标范围为[-2, 3]
  xx = np.arange(-2, 3, 0.01)
  # 可视化的 y 坐标范围为[-1.5, 2]
  yy = np.arange(-1.5, 2, 0.01)
  # 生成 x-y 平面采样网格点,方便可视化
  XX, YY = np.meshgrid(xx, yy)
  preds = model.predict_classes(np.c_[XX.ravel(), YY.ravel()])
  print(preds)
  title = "网络层数:{0}".format(2 + n)
  file = "网络容量_%i.png" % (2 + n)
  make_plot(X_train, y_train, title, file, XX, YY, preds, output_dir=OUTPUT_DIR + '/network_layers')

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

3.Dropout的影响

为了探讨 Dropout 层对网络训练的影响,我们共进行了 5 次实验,每次实验使用 7 层的全连接层网络进行训练,但是在全连接层中间隔插入 0~4 个 Dropout 层并通过 Adam优化器训练 500 个 Epoch

def dropout_influence(X_train, y_train):
 # 构建 5 种不同数量 Dropout 层的网络
 for n in range(5):
  # 创建容器
  model = Sequential()
  # 创建第一层
  model.add(layers.Dense(8, input_dim=2, activation='relu'))
  counter = 0
  # 网络层数固定为 5
  for _ in range(5):
   model.add(layers.Dense(64, activation='relu'))
  # 添加 n 个 Dropout 层
   if counter < n:
    counter += 1
    model.add(layers.Dropout(rate=0.5))

  # 输出层
  model.add(layers.Dense(1, activation='sigmoid'))
  # 模型装配
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  # 训练
  model.fit(X_train, y_train, epochs=N_EPOCHS, verbose=1)
  # 绘制不同 Dropout 层数的决策边界曲线
  # 可视化的 x 坐标范围为[-2, 3]
  xx = np.arange(-2, 3, 0.01)
  # 可视化的 y 坐标范围为[-1.5, 2]
  yy = np.arange(-1.5, 2, 0.01)
  # 生成 x-y 平面采样网格点,方便可视化
  XX, YY = np.meshgrid(xx, yy)
  preds = model.predict_classes(np.c_[XX.ravel(), YY.ravel()])
  title = "无Dropout层" if n == 0 else "{0}层 Dropout层".format(n)
  file = "Dropout_%i.png" % n
  make_plot(X_train, y_train, title, file, XX, YY, preds, output_dir=OUTPUT_DIR + '/dropout')

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

4.正则化的影响

为了探讨正则化系数?对网络模型训练的影响,我们采用 L2 正则化方式,构建了 5 层的神经网络,其中第 2,3,4 层神经网络层的权值张量 W 均添加 L2 正则化约束项:

def build_model_with_regularization(_lambda):
 # 创建带正则化项的神经网络
 model = Sequential()
 model.add(layers.Dense(8, input_dim=2, activation='relu')) # 不带正则化项
 # 2-4层均是带 L2 正则化项
 model.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(_lambda)))
 model.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(_lambda)))
 model.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(_lambda)))
 # 输出层
 model.add(layers.Dense(1, activation='sigmoid'))
 model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # 模型装配
 return model

下面我们首先来实现一个权重可视化的函数

def plot_weights_matrix(model, layer_index, plot_name, file_name, output_dir=OUTPUT_DIR):
 # 绘制权值范围函数
 # 提取指定层的权值矩阵
 weights = model.layers[layer_index].get_weights()[0]
 shape = weights.shape
 # 生成和权值矩阵等大小的网格坐标
 X = np.array(range(shape[1]))
 Y = np.array(range(shape[0]))
 X, Y = np.meshgrid(X, Y)
 # 绘制3D图
 fig = plt.figure()
 ax = fig.gca(projection='3d')
 ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
 ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
 ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
 plt.title(plot_name, fontsize=20, fontproperties='SimHei')
 # 绘制权值矩阵范围
 ax.plot_surface(X, Y, weights, cmap=plt.get_cmap('rainbow'), linewidth=0)
 # 设置坐标轴名
 ax.set_xlabel('网格x坐标', fontsize=16, rotation=0, fontproperties='SimHei')
 ax.set_ylabel('网格y坐标', fontsize=16, rotation=0, fontproperties='SimHei')
 ax.set_zlabel('权值', fontsize=16, rotation=90, fontproperties='SimHei')
 # 保存矩阵范围图
 plt.savefig(output_dir + "/" + file_name + ".svg")
 plt.close(fig)

在保持网络结构不变的条件下,我们通过调节正则化系数 ? = 0.00001,0.001,0.1,0.12,0.13 来测试网络的训练效果,并绘制出学习模型在训练集上的决策边界曲线

def regularizers_influence(X_train, y_train):
 for _lambda in [1e-5, 1e-3, 1e-1, 0.12, 0.13]: # 设置不同的正则化系数
  # 创建带正则化项的模型
  model = build_model_with_regularization(_lambda)
  # 模型训练
  model.fit(X_train, y_train, epochs=N_EPOCHS, verbose=1)
  # 绘制权值范围
  layer_index = 2
  plot_title = "正则化系数:{}".format(_lambda)
  file_name = "正则化网络权值_" + str(_lambda)
  # 绘制网络权值范围图
  plot_weights_matrix(model, layer_index, plot_title, file_name, output_dir=OUTPUT_DIR + '/regularizers')
  # 绘制不同正则化系数的决策边界线
  # 可视化的 x 坐标范围为[-2, 3]
  xx = np.arange(-2, 3, 0.01)
  # 可视化的 y 坐标范围为[-1.5, 2]
  yy = np.arange(-1.5, 2, 0.01)
  # 生成 x-y 平面采样网格点,方便可视化
  XX, YY = np.meshgrid(xx, yy)
  preds = model.predict_classes(np.c_[XX.ravel(), YY.ravel()])
  title = "正则化系数:{}".format(_lambda)
  file = "正则化_%g.svg" % _lambda
  make_plot(X_train, y_train, title, file, XX, YY, preds, output_dir=OUTPUT_DIR + '/regularizers')
regularizers_influence(X_train, y_train)

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

详解tensorflow之过拟合问题实战

到此这篇关于详解tensorflow之过拟合问题实战的文章就介绍到这了,更多相关tensorflow 过拟合内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现在IDLE中输入多行的方法
Apr 19 Python
Matplotlib中文乱码的3种解决方案
Nov 15 Python
python数据预处理之数据标准化的几种处理方式
Jul 17 Python
django框架CSRF防护原理与用法分析
Jul 22 Python
python 接口实现 供第三方调用的例子
Aug 13 Python
Django如何实现网站注册用户邮箱验证功能
Aug 14 Python
python GUI库图形界面开发之PyQt5多线程中信号与槽的详细使用方法与实例
Mar 08 Python
在python中求分布函数相关的包实例
Apr 15 Python
MAC平台基于Python Appium环境搭建过程图解
Aug 13 Python
基于PyInstaller各参数的含义说明
Mar 04 Python
python中requests库+xpath+lxml简单使用
Apr 29 Python
anaconda python3.8安装后降级
Jun 11 Python
python cookie反爬处理的实现
Nov 01 #Python
10个python爬虫入门实例(小结)
Nov 01 #Python
利用pipenv和pyenv管理多个相互独立的Python虚拟开发环境
Nov 01 #Python
Python经纬度坐标转换为距离及角度的实现
Nov 01 #Python
详解Anaconda安装tensorflow报错问题解决方法
Nov 01 #Python
python Cartopy的基础使用详解
Nov 01 #Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 #Python
You might like
php将数据库导出成excel的方法
2010/05/07 PHP
简单实现限定phpmyadmin访问ip的方法
2013/03/05 PHP
二进制交叉权限微型php类分享
2014/02/07 PHP
PHP变量赋值、代入给JavaScript中的变量
2015/06/29 PHP
是 WordPress 让 PHP 更流行了 而不是框架
2016/02/03 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
JS array 数组详解
2009/03/22 Javascript
js实现广告漂浮效果的小例子
2013/07/02 Javascript
js/html光标定位的实现代码
2013/09/23 Javascript
JavaScript判断访问的来源是手机还是电脑,用的哪种浏览器
2013/12/12 Javascript
BootStrap.css 在手机端滑动时右侧出现空白的原因及解决办法
2016/06/07 Javascript
js replace()去除代码中空格的实例
2017/02/14 Javascript
js实现手机发送验证码功能
2017/03/13 Javascript
vue.js中过滤器的使用教程
2017/06/08 Javascript
使用Vue组件实现一个简单弹窗效果
2018/04/23 Javascript
Vue2.0使用嵌套路由实现页面内容切换/公用一级菜单控制页面内容切换(推荐)
2019/05/08 Javascript
CKeditor富文本编辑器使用技巧之添加自定义插件的方法
2019/06/14 Javascript
Layui之table中的radio在切换分页时无法记住选中状态的解决方法
2019/09/02 Javascript
extjs4图表绘制之折线图实现方法分析
2020/03/06 Javascript
学习python中matplotlib绘图设置坐标轴刻度、文本
2018/02/07 Python
怎么使用pipenv管理你的python项目
2018/03/12 Python
Python将DataFrame的某一列作为index的方法
2018/04/08 Python
Python使用Pickle库实现读写序列操作示例
2018/06/15 Python
Python OpenCV实现视频分帧
2019/06/01 Python
Django-Model数据库操作(增删改查、连表结构)详解
2019/07/17 Python
pandas 选取行和列数据的方法详解
2019/08/08 Python
Python内置函数property()如何使用
2020/09/01 Python
Pycharm Plugins加载失败问题解决方案
2020/11/28 Python
基于第一个PhoneGap(cordova)的应用详解
2013/05/03 HTML / CSS
新西兰最大的在线设计师眼镜店:SmartBuyGlasses新西兰
2017/10/20 全球购物
墨西哥网上超市:Superama
2018/07/10 全球购物
八一建军节营销活动方案
2014/08/31 职场文书
餐厅服务员岗位职责
2015/02/09 职场文书
2015年超市员工工作总结
2015/05/04 职场文书
Python使用华为API为图像设置多个锚点标签
2022/04/12 Python
MySQL中的 inner join 和 left join的区别解析(小结果集驱动大结果集)
2023/05/08 MySQL