如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
windows下python连接oracle数据库
Jun 07 Python
Python网络爬虫与信息提取(实例讲解)
Aug 29 Python
Python网络爬虫神器PyQuery的基本使用教程
Feb 03 Python
Sanic框架Cookies操作示例
Jul 17 Python
pytorch在fintune时将sequential中的层输出方法,以vgg为例
Aug 20 Python
Python Django2.0集成Celery4.1教程
Nov 19 Python
python3 requests库实现多图片爬取教程
Dec 18 Python
Python利用Scrapy框架爬取豆瓣电影示例
Jan 17 Python
Python函数递归调用实现原理实例解析
Aug 11 Python
Python-OpenCV实现图像缺陷检测的实例
Jun 11 Python
python数据可视化使用pyfinance分析证券收益示例详解
Nov 20 Python
Python中with上下文管理协议的作用及用法
Mar 18 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP网页游戏学习之Xnova(ogame)源码解读(五)
2014/06/23 PHP
PHP实现在线阅读PDF文件的方法
2015/06/23 PHP
深入讲解PHP Session及如何保持其不过期的方法
2015/08/18 PHP
smarty学习笔记之常见代码段用法总结
2016/03/19 PHP
PHP+原生态ajax实现的省市联动功能详解
2017/08/15 PHP
基于jQuery架构javascript基础体系
2011/01/01 Javascript
js 静态动态成员 and 信息的封装和隐藏
2011/05/29 Javascript
JavaScript 的继承
2011/10/01 Javascript
使用JSLint提高JS代码质量方法分享
2013/12/16 Javascript
jquery解决客户端跨域访问问题
2015/01/06 Javascript
JavaScript中的Math.atan2()方法使用详解
2015/06/15 Javascript
jquery小火箭返回顶部代码分享
2015/08/19 Javascript
使用Sticky组件实现带sticky效果的tab导航和滚动导航的方法
2016/03/22 Javascript
jQuery使用getJSON方法获取json数据完整示例
2016/09/13 Javascript
基于NodeJS+MongoDB+AngularJS+Bootstrap开发书店案例分析
2017/01/12 NodeJs
JavaScript数据结构之二叉树的删除算法示例
2017/04/13 Javascript
three.js中文文档学习之创建场景
2017/11/20 Javascript
nodejs结合socket.io实现websocket通信功能的方法
2018/01/12 NodeJs
Vue中使用ElementUI使用第三方图标库iconfont的示例
2018/10/11 Javascript
代码实例ajax实现点击加载更多数据图片
2018/10/12 Javascript
vue中使用cookies和crypto-js实现记住密码和加密的方法
2018/10/18 Javascript
新手简单了解vue
2019/05/29 Javascript
vue 修改 data 数据问题并实时显示操作
2020/09/07 Javascript
[47:10]完美世界DOTA2联赛PWL S3 LBZS vs Rebirth 第二场 12.16
2020/12/18 DOTA
Windows下的Jupyter Notebook 安装与自定义启动(图文详解)
2018/02/21 Python
Python爬虫小技巧之伪造随机的User-Agent
2018/09/13 Python
解决pycharm 远程调试 上传 helpers 卡住的问题
2019/06/27 Python
Html5 new XMLHttpRequest()监听附件上传进度
2021/01/14 HTML / CSS
马来西亚在线购物市场:PGMall.my
2019/10/13 全球购物
娇韵诗香港官网:Clarins香港
2020/08/13 全球购物
指针和引用有什么区别
2013/01/13 面试题
和谐社区口号
2014/06/19 职场文书
心灵点滴观后感
2015/06/02 职场文书
使用HTML+Css+transform实现3D导航栏的示例代码
2021/03/31 HTML / CSS
教你使用Pandas直接核算Excel中快递费用
2021/05/12 Python
JavaScript获取URL参数的方法分享
2022/04/07 Javascript