如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python使用正则表达式分析网页中的图片并进行替换的方法
Mar 26 Python
Python实现在matplotlib中两个坐标轴之间画一条直线光标的方法
May 20 Python
使用Python生成随机密码的示例分享
Feb 18 Python
使用python画个小猪佩奇的示例代码
Jun 06 Python
python 处理string到hex脚本的方法
Oct 26 Python
解决Python安装时报缺少DLL问题【两种解决方法】
Jul 15 Python
python Django的web开发实例(入门)
Jul 31 Python
python爬虫selenium和phantomJs使用方法解析
Aug 08 Python
pytorch 实现删除tensor中的指定行列
Jan 13 Python
手动安装python3.6的操作过程详解
Jan 13 Python
NumPy排序的实现
Jan 21 Python
python产生模拟数据faker库的使用详解
Nov 04 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP中static关键字原理的学习研究分析
2011/07/18 PHP
php时间戳转换的示例
2014/03/31 PHP
fsockopen pfsockopen函数被禁用,SMTP发送邮件不正常的解决方法
2015/09/20 PHP
PHP数组生成XML格式数据的封装类实例
2016/11/10 PHP
javascript类继承机制的原理分析
2009/09/12 Javascript
javascript获取网页中指定节点的父节点、子节点的方法小结
2013/04/24 Javascript
js取两个数组的交集|差集|并集|补集|去重示例代码
2013/08/07 Javascript
Extjs表单常见验证小结
2014/03/07 Javascript
使用text方法获取Html元素文本信息示例
2014/09/01 Javascript
全面解析Bootstrap表单使用方法(表单按钮)
2015/11/24 Javascript
简单的vue-resourse获取json并应用到模板示例
2017/02/10 Javascript
详解vue.js移动端导航navigationbar的封装
2017/07/05 Javascript
Angular2 http jsonp的实例详解
2017/08/31 Javascript
使用Bootrap和Vue实现仿百度搜索功能
2017/10/26 Javascript
vue.js过滤器+ajax实现事件监听及后台php数据交互实例
2018/05/22 Javascript
利用Node.js批量抓取高清妹子图片实例教程
2018/08/02 Javascript
详解Angular cli配置过程记录
2019/11/07 Javascript
Vue实现剪贴板复制功能
2019/12/31 Javascript
[49:15]DOTA2-DPC中国联赛 正赛 CDEC vs XG BO3 第二场 1月19日
2021/03/11 DOTA
编写Python CGI脚本的教程
2015/06/29 Python
python实现汉诺塔方法汇总
2016/07/25 Python
Python numpy 提取矩阵的某一行或某一列的实例
2018/04/03 Python
使用 Python 玩转 GitHub 的贡献板(推荐)
2019/04/04 Python
python实现统计文本中单词出现的频率详解
2019/05/20 Python
Python3内置模块random随机方法小结
2019/07/13 Python
CSS3 二级导航菜单的制作的示例
2018/04/02 HTML / CSS
JAVA程序员面试题
2012/10/03 面试题
电子商务专业学生的自我鉴定
2013/11/28 职场文书
医务工作者先进事迹材料
2014/01/26 职场文书
QQ空间主人寄语大全
2014/04/12 职场文书
战略性融资合作协议书范本
2014/10/17 职场文书
医药公司采购员岗位职责
2015/04/03 职场文书
高中历史教学反思
2016/02/19 职场文书
Nginx流量拷贝ngx_http_mirror_module模块使用方法详解
2022/04/07 Servers
Java数据结构之堆(优先队列)
2022/05/20 Java/Android
Spring Boot项目如何优雅实现Excel导入与导出功能
2022/06/10 Java/Android