如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python全局变量用法实例分析
Jul 19 Python
python getopt详解及简单实例
Dec 30 Python
python数字图像处理之高级滤波代码详解
Nov 23 Python
基于Django用户认证系统详解
Feb 21 Python
Python结合ImageMagick实现多张图片合并为一个pdf文件的方法
Apr 24 Python
利用Pandas读取文件路径或文件名称包含中文的csv文件方法
Jul 04 Python
python 递归深度优先搜索与广度优先搜索算法模拟实现
Oct 22 Python
对numpy中向量式三目运算符详解
Oct 31 Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 Python
Selenium环境变量配置(火狐浏览器)及验证实现
Dec 07 Python
Python读写Excel表格的方法
Mar 02 Python
Pandas 稀疏数据结构的实现
Jul 25 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP 生成的XML以FLASH获取为乱码终极解决
2009/08/07 PHP
php 5.3.5安装memcache注意事项小结
2011/04/12 PHP
PHP学习散记_编码(json_encode 中文不显示)
2011/11/10 PHP
PHP7标量类型declare用法实例分析
2016/09/26 PHP
来自chinaz的ajax获取评论代码
2008/05/03 Javascript
js Date自定义函数 延迟脚本执行
2010/03/10 Javascript
Struts2的s:radio标签使用及用jquery添加change事件
2013/04/08 Javascript
js实现鼠标点击文本框自动选中内容的方法
2015/08/20 Javascript
iScroll.js 使用方法参考
2016/05/16 Javascript
bootstrap输入框组使用方法
2017/02/07 Javascript
AngularJS中的拦截器实例详解
2017/04/07 Javascript
python爬取安居客二手房网站数据(实例讲解)
2017/10/19 Javascript
springMvc 前端用json的方式向后台传递对象数组方法
2018/08/07 Javascript
Echart折线图手柄触发事件示例详解
2018/12/16 Javascript
原生JS实现弹幕效果的简单操作指南
2020/11/10 Javascript
JS前端基于canvas给图片添加水印
2020/11/11 Javascript
[20:21]《一刀刀一天》第十六期:TI国际邀请赛正式打响,总奖金超过550万
2014/05/23 DOTA
Python中请使用isinstance()判断变量类型
2014/08/25 Python
举例讲解Python面向对象编程中类的继承
2016/06/17 Python
python中解析json格式文件的方法示例
2017/05/03 Python
Python操作MySQL数据库的三种方法总结
2018/01/30 Python
opencv3/C++图像像素操作详解
2019/12/10 Python
Python利用matplotlib绘制折线图的新手教程
2020/11/05 Python
Python爬取某平台短视频的方法
2021/02/08 Python
css实例教程 一款纯css3实现的超炫动画背画特效
2014/11/05 HTML / CSS
一款CSS3实现多功能下拉菜单(带分享按)的教程
2014/11/05 HTML / CSS
html5开发三八女王节表白神器
2018/03/07 HTML / CSS
我有一个char * 型指针正巧指向一些int 型变量, 我想跳过它们。 为什么如下的代码((int *)p)++; 不行?
2013/05/09 面试题
机械工程师求职自我评价
2013/09/23 职场文书
房地产销售计划书
2014/01/10 职场文书
骨干教师考核方案
2014/05/09 职场文书
应届生求职信
2014/05/31 职场文书
水电工程师岗位职责
2015/02/13 职场文书
安全承诺书格式范本
2015/04/28 职场文书
2016大一新生军训心得体会
2016/01/11 职场文书
如何撰写出一份完美的商业计划书?
2019/07/12 职场文书