如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python数据处理实战(必看篇)
Jun 11 Python
Python竟能画这么漂亮的花,帅呆了(代码分享)
Nov 15 Python
Python+matplotlib实现华丽的文本框演示代码
Jan 22 Python
python 脚本生成随机 字母 + 数字密码功能
May 26 Python
对Python的zip函数妙用,旋转矩阵详解
Dec 13 Python
Python基于聚类算法实现密度聚类(DBSCAN)计算【测试可用】
Dec 26 Python
python的pyecharts绘制各种图表详细(附代码)
Nov 11 Python
Python django搭建layui提交表单,表格,图标的实例
Nov 18 Python
基于Tensorflow批量数据的输入实现方式
Feb 05 Python
python获取百度热榜链接的实例方法
Aug 25 Python
python报错TypeError: ‘NoneType‘ object is not subscriptable的解决方法
Nov 05 Python
pycharm 的Structure界面设置操作
Feb 05 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP 程序员的调试技术小结
2009/11/15 PHP
php中记录用户访问过的产品,在cookie记录产品id,id取得产品信息
2011/05/04 PHP
PHP的基本常识小结
2013/07/05 PHP
关于PHP通用返回值设置方法
2017/03/31 PHP
PHP文件上传小程序 适合初学者学习!
2019/05/23 PHP
PHP sdk实现在线打包代码示例
2020/12/09 PHP
CL vs ForZe BO5 第四场 2.13
2021/03/10 DOTA
jQuery .attr()和.removeAttr()方法操作元素属性示例
2013/07/16 Javascript
JS 页面计时器示例代码
2013/10/28 Javascript
jQuery截取指定长度字符串的实现原理及代码
2014/07/01 Javascript
JavaScript中计算网页中某个元素的位置
2015/06/10 Javascript
js停止冒泡和阻止浏览器默认行为的简单方法
2016/05/15 Javascript
JS中Select下拉列表类(支持输入模糊查询)功能
2017/01/17 Javascript
jQuery.cookie.js实现记录最近浏览过的商品功能示例
2017/01/23 Javascript
JS批量替换内容中关键词为超链接
2017/02/20 Javascript
js实现鼠标拖动功能
2017/03/20 Javascript
解决bootstrap模态框数据缓存的问题方法
2018/08/10 Javascript
angularJs中跳转到指定的锚点实例($anchorScroll)
2018/08/31 Javascript
angularJS1 url中携带参数的获取方法
2018/10/09 Javascript
swiper4实现移动端导航切换
2020/10/16 Javascript
使用Taro实现小程序商城的购物车功能模块的实例代码
2020/06/05 Javascript
Vue绑定用户接口实现代码示例
2020/11/04 Javascript
深入理解Django自定义信号(signals)
2018/10/15 Python
python导入模块交叉引用的方法
2019/01/19 Python
QML使用Python的函数过程解析
2019/09/26 Python
python django中8000端口被占用的解决
2019/12/17 Python
python 窃取摄像头照片的实现示例
2021/01/08 Python
澳大利亚在线时尚精品店:Hello Molly
2018/02/26 全球购物
英国女性运动服品牌:Sweaty Betty
2018/11/08 全球购物
介绍一下EJB的体系结构
2012/08/01 面试题
应届大学生的推荐信
2013/11/20 职场文书
工作表扬信的范文
2014/01/10 职场文书
收费员岗位职责
2015/02/14 职场文书
英语导游欢迎词
2015/09/30 职场文书
建房合同协议书
2016/03/21 职场文书
MySQL 逻辑备份与恢复测试的相关总结
2021/05/14 MySQL