如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python 不同对象比较大小示例探讨
Aug 21 Python
跟老齐学Python之永远强大的函数
Sep 14 Python
Python实现字符串格式化的方法小结
Feb 20 Python
TensorFlow安装及jupyter notebook配置方法
Sep 08 Python
对Python中Iterator和Iterable的区别详解
Oct 18 Python
Python实现二叉树的最小深度的两种方法
Sep 30 Python
python 采用paramiko 远程执行命令及报错解决
Oct 21 Python
python 检查数据中是否有缺失值,删除缺失值的方式
Dec 02 Python
python requests模拟登陆github的实现方法
Dec 26 Python
详解Python中pyautogui库的最全使用方法
Apr 01 Python
pymysql之cur.fetchall() 和cur.fetchone()用法详解
May 15 Python
python将YUV420P文件转PNG图片格式的两种方法
Jan 22 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP编码规范-php coding standard
2007/03/16 PHP
简单介绍下 PHP5 中引入的 MYSQLI的用途
2007/03/19 PHP
php str_pad 函数用法简介
2009/07/11 PHP
php把数据表导出为Excel表的最简单、最快的方法(不用插件)
2014/05/10 PHP
PHP检测数据类型的几种方法(总结)
2017/03/04 PHP
Thinkphp5.0框架视图view的循环标签用法示例
2019/10/12 PHP
基于jquery的地址栏射击游戏代码
2011/03/10 Javascript
jQuery的选择器中的通配符使用介绍
2014/03/20 Javascript
JavaScript观察者模式(经典)
2015/12/09 Javascript
Bootstrap三种表单布局的使用方法
2016/06/21 Javascript
轻松实现js弹框显示选项
2016/09/13 Javascript
Angular2平滑升级到Angular4的步骤详解
2017/03/29 Javascript
element vue Array数组和Map对象的添加与删除操作
2018/11/14 Javascript
小程序页面动态配置实现方法
2019/02/05 Javascript
JS实现关闭小广告特效
2021/01/29 Javascript
vue实现验证用户名是否可用
2021/01/20 Vue.js
[02:43]DOTA2亚洲邀请赛场馆攻略——带你走进东方体育中心
2018/03/19 DOTA
详解python实现读取邮件数据并下载附件的实例
2017/08/03 Python
Python实现将照片变成卡通图片的方法【基于opencv】
2018/01/17 Python
PyCharm安装第三方库如Requests的图文教程
2018/05/18 Python
Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例
2018/07/27 Python
pycharm运行程序时在Python console窗口中运行的方法
2018/12/03 Python
Python3中函数参数传递方式实例详解
2019/05/05 Python
基于Python实现大文件分割和命名脚本过程解析
2019/09/29 Python
python 读取数据库并绘图的实例
2019/12/03 Python
Python制作数据预测集成工具(值得收藏)
2020/08/21 Python
HTML5新表单元素_动力节点Java学院整理
2017/07/12 HTML / CSS
详解canvas绘制多张图的排列顺序问题
2019/01/21 HTML / CSS
德国街头和运动文化高品质商店:BSTN Store
2017/08/26 全球购物
Bluebella德国官网:英国性感内衣和睡衣品牌
2019/11/08 全球购物
英国最大的独立摄影零售商:Park Cameras
2019/11/27 全球购物
在C中是否有模拟继承等面向对象程序设计特性的好方法
2012/05/22 面试题
幼儿园体操比赛口号
2015/12/25 职场文书
oracle DGMGRL ORA-16603报错的解决方法(DG Broker)
2021/04/06 Oracle
python playwright 自动等待和断言详解
2021/11/27 Python
Go微服务项目配置文件的定义和读取示例详解
2022/06/21 Golang