如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
用python实现的去除win下文本文件头部BOM的代码
Feb 10 Python
Python删除windows垃圾文件的方法
Jul 14 Python
常用python编程模板汇总
Feb 12 Python
Python 迭代器工具包【推荐】
May 06 Python
解决Python selenium get页面很慢时的问题
Jan 30 Python
Python pandas.DataFrame调整列顺序及修改index名的方法
Jun 21 Python
python基于递归解决背包问题详解
Jul 03 Python
pytorch模型预测结果与ndarray互转方式
Jan 15 Python
Xadmin+rules实现多选行权限方式(级联效果)
Apr 07 Python
python生成xml时规定dtd实例方法
Sep 21 Python
scrapy-splash简单使用详解
Feb 21 Python
Python爬虫 简单介绍一下Xpath及使用
Apr 26 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
php checkdate、getdate等日期时间函数操作详解
2010/03/11 PHP
php输出表格的实现代码(修正版)
2010/12/29 PHP
30 个很棒的PHP开源CMS内容管理系统小结
2011/10/14 PHP
PHP 文件编程综合案例-文件上传的实现
2013/07/03 PHP
Symfony2安装第三方Bundles实例详解
2016/02/04 PHP
简单的自定义php模板引擎
2016/08/26 PHP
PDO::rollBack讲解
2019/01/29 PHP
ExtJS 2.0实用简明教程 之获得ExtJS
2009/04/29 Javascript
jquery UI 1.72 之datepicker
2009/12/29 Javascript
jQuery toggleClass应用实例(附效果图)
2014/04/06 Javascript
原生javascript实现DIV拖拽并计算重复面积
2015/01/02 Javascript
jQuery与getJson结合的用法实例
2015/08/07 Javascript
纯javascript判断查询日期是否为有效日期
2015/08/24 Javascript
jQuery实现点击按钮弹出可关闭层的浮动层插件
2015/09/19 Javascript
解析jquery easyui tree异步加载子节点问题
2017/03/08 Javascript
详解React项目如何修改打包地址(编译输出文件地址)
2019/03/21 Javascript
Vue开发环境中修改端口号的实现方法
2019/08/15 Javascript
[02:16]2018年度CS GO最具人气选手-完美盛典
2018/12/16 DOTA
Python异常对代码运行性能的影响实例解析
2018/02/08 Python
在pandas中一次性删除dataframe的多个列方法
2018/04/10 Python
win7 x64系统中安装Scrapy的方法
2018/11/18 Python
Python3爬虫学习入门教程
2018/12/11 Python
Django+Xadmin构建项目的方法步骤
2019/03/06 Python
Python爬取智联招聘数据分析师岗位相关信息的方法
2019/08/13 Python
OpenCV模板匹配matchTemplate的实现
2019/10/18 Python
Pycharm2020最新激活码|永久激活(附最新激活码和插件的详细教程)
2020/09/29 Python
PHP如何删除一个Cookie值
2012/11/15 面试题
质检的岗位职责
2013/11/17 职场文书
最新的互联网创业计划书
2014/01/10 职场文书
硕士研究生求职自荐信范文
2014/03/11 职场文书
2014大学班主任工作总结
2014/11/08 职场文书
餐饮服务员岗位职责
2015/02/09 职场文书
2015年技术工作总结范文
2015/04/20 职场文书
开学第一周总结
2015/07/16 职场文书
React实现动效弹窗组件
2021/06/21 Javascript
CSS实现切角+边框+投影+内容背景色渐变效果
2021/11/01 HTML / CSS