神经网络(BP)算法Python实现及应用


Posted in Python onApril 16, 2018

本文实例为大家分享了Python实现神经网络算法及应用的具体代码,供大家参考,具体内容如下

首先用Python实现简单地神经网络算法:

import numpy as np


# 定义tanh函数
def tanh(x):
  return np.tanh(x)


# tanh函数的导数
def tan_deriv(x):
  return 1.0 - np.tanh(x) * np.tan(x)


# sigmoid函数
def logistic(x):
  return 1 / (1 + np.exp(-x))


# sigmoid函数的导数
def logistic_derivative(x):
  return logistic(x) * (1 - logistic(x))


class NeuralNetwork:
  def __init__(self, layers, activation='tanh'):
    """
    神经网络算法构造函数
    :param layers: 神经元层数
    :param activation: 使用的函数(默认tanh函数)
    :return:none
    """
    if activation == 'logistic':
      self.activation = logistic
      self.activation_deriv = logistic_derivative
    elif activation == 'tanh':
      self.activation = tanh
      self.activation_deriv = tan_deriv

    # 权重列表
    self.weights = []
    # 初始化权重(随机)
    for i in range(1, len(layers) - 1):
      self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
      self.weights.append((2 * np.random.random((layers[i] + 1, layers[i + 1])) - 1) * 0.25)

  def fit(self, X, y, learning_rate=0.2, epochs=10000):
    """
    训练神经网络
    :param X: 数据集(通常是二维)
    :param y: 分类标记
    :param learning_rate: 学习率(默认0.2)
    :param epochs: 训练次数(最大循环次数,默认10000)
    :return: none
    """
    # 确保数据集是二维的
    X = np.atleast_2d(X)

    temp = np.ones([X.shape[0], X.shape[1] + 1])
    temp[:, 0: -1] = X
    X = temp
    y = np.array(y)

    for k in range(epochs):
      # 随机抽取X的一行
      i = np.random.randint(X.shape[0])
      # 用随机抽取的这一组数据对神经网络更新
      a = [X[i]]
      # 正向更新
      for l in range(len(self.weights)):
        a.append(self.activation(np.dot(a[l], self.weights[l])))
      error = y[i] - a[-1]
      deltas = [error * self.activation_deriv(a[-1])]

      # 反向更新
      for l in range(len(a) - 2, 0, -1):
        deltas.append(deltas[-1].dot(self.weights[l].T) * self.activation_deriv(a[l]))
        deltas.reverse()
      for i in range(len(self.weights)):
        layer = np.atleast_2d(a[i])
        delta = np.atleast_2d(deltas[i])
        self.weights[i] += learning_rate * layer.T.dot(delta)

  def predict(self, x):
    x = np.array(x)
    temp = np.ones(x.shape[0] + 1)
    temp[0:-1] = x
    a = temp
    for l in range(0, len(self.weights)):
      a = self.activation(np.dot(a, self.weights[l]))
    return a

使用自己定义的神经网络算法实现一些简单的功能:

 小案例:

X:                  Y
0 0                 0
0 1                 1
1 0                 1
1 1                 0

from NN.NeuralNetwork import NeuralNetwork
import numpy as np

nn = NeuralNetwork([2, 2, 1], 'tanh')
temp = [[0, 0], [0, 1], [1, 0], [1, 1]]
X = np.array(temp)
y = np.array([0, 1, 1, 0])
nn.fit(X, y)
for i in temp:
  print(i, nn.predict(i))

神经网络(BP)算法Python实现及应用

发现结果基本机制,无限接近0或者无限接近1 

第二个例子:识别图片中的数字

导入数据:

from sklearn.datasets import load_digits
import pylab as pl

digits = load_digits()
print(digits.data.shape)
pl.gray()
pl.matshow(digits.images[0])
pl.show()

观察下:大小:(1797, 64)

数字0

神经网络(BP)算法Python实现及应用

接下来的代码是识别它们:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
from NN.NeuralNetwork import NeuralNetwork
from sklearn.cross_validation import train_test_split

# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target
# 处理数据,使得数据处于0,1之间,满足神经网络算法的要求
X -= X.min()
X /= X.max()

# 层数:
# 输出层10个数字
# 输入层64因为图片是8*8的,64像素
# 隐藏层假设100
nn = NeuralNetwork([64, 100, 10], 'logistic')
# 分隔训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y)

# 转化成sklearn需要的二维数据类型
labels_train = LabelBinarizer().fit_transform(y_train)
labels_test = LabelBinarizer().fit_transform(y_test)
print("start fitting")
# 训练3000次
nn.fit(X_train, labels_train, epochs=3000)
predictions = []
for i in range(X_test.shape[0]):
  o = nn.predict(X_test[i])
  # np.argmax:第几个数对应最大概率值
  predictions.append(np.argmax(o))

# 打印预测相关信息
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

结果:

矩阵对角线代表预测正确的数量,发现正确率很多

神经网络(BP)算法Python实现及应用

这张表更直观地显示出预测正确率:

共450个案例,成功率94%

神经网络(BP)算法Python实现及应用

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于Tkinter库实现简单文本编辑器实例
May 05 Python
python中 chr unichr ord函数的实例详解
Aug 06 Python
教你使用python实现微信每天给女朋友说晚安
Mar 23 Python
python获取url的返回信息方法
Dec 17 Python
Python 安装第三方库 pip install 安装慢安装不上的解决办法
Jun 18 Python
django 实现celery动态设置周期任务执行时间
Nov 19 Python
python中的逆序遍历实例
Dec 25 Python
Python转换字典成为对象,可以用"."方式访问对象属性实例
May 11 Python
python中有帮助函数吗
Jun 19 Python
python3.4中清屏的处理方法
Jul 06 Python
基于Python实现体育彩票选号器功能代码实例
Sep 16 Python
python快速安装OpenCV的步骤记录
Feb 22 Python
python读取视频流提取视频帧的两种方法
Oct 22 #Python
python读取和保存视频文件
Apr 16 #Python
Python读取视频的两种方法(imageio和cv2)
Apr 15 #Python
python2.7实现FTP文件下载功能
Apr 15 #Python
python实现多线程网页下载器
Apr 15 #Python
Python实现定时精度可调节的定时器
Apr 15 #Python
Python编写一个优美的下载器
Apr 15 #Python
You might like
生成静态页面的php函数,php爱好者站推荐
2007/03/19 PHP
PHP中“=>
2019/03/01 PHP
jQuery ui 1.7更新小结
2009/08/15 Javascript
js switch case default 的用法示例介绍
2013/10/23 Javascript
js实现局部页面打印预览原理及示例代码
2014/07/03 Javascript
javascript将url中的参数加密解密代码
2014/11/17 Javascript
JavaScript仿支付宝密码输入框
2015/12/29 Javascript
jquery 遍历数组 each 方法详解
2016/05/25 Javascript
Bootstrap+jfinal退出系统弹出确认框的实现方法
2016/05/30 Javascript
JavaScript每天必学之基础知识
2016/09/17 Javascript
js 提交form表单和设置form表单请求路径的实现方法
2016/10/25 Javascript
js实现拖拽上传图片功能
2017/08/01 Javascript
jquery实现倒计时小应用
2017/09/19 jQuery
优雅的在React项目中使用Redux的方法
2018/11/10 Javascript
JavaScript 面向对象程序设计详解【类的创建、实例对象、构造函数、原型等】
2020/05/12 Javascript
[02:57]DOTA2亚洲邀请赛小组赛第四日 赛事回顾
2015/02/02 DOTA
[00:12]DAC2018 天才少年转战三号位,他的SOLO是否仍如昔日般强大?
2018/04/06 DOTA
详细解读tornado协程(coroutine)原理
2018/01/15 Python
python利用插值法对折线进行平滑曲线处理
2018/12/25 Python
使用django的objects.filter()方法匹配多个关键字的方法
2019/07/18 Python
python基于pdfminer库提取pdf文字代码实例
2019/08/15 Python
css3实现信纸/同学录效果的示例代码
2018/12/11 HTML / CSS
使用JS+CSS3技术:让你的名字动起来
2013/04/27 HTML / CSS
快速创建 HTML5 Canvas 电信网络拓扑图的示例代码
2018/03/21 HTML / CSS
阿提哈德航空官方网站:Etihad Airways
2017/01/06 全球购物
康帕斯酒店预订:Compass Hospitality(支持中文)
2018/08/23 全球购物
巴西补充剂和维生素购物网站:Natue
2019/06/17 全球购物
SQL里面如何插入自动增长序列号字段
2012/03/29 面试题
材料物理专业大学毕业生求职信
2013/10/15 职场文书
销售主管的自我评价分享
2014/01/03 职场文书
《有趣的发现》教学反思
2014/04/15 职场文书
外联部演讲稿
2014/05/24 职场文书
试用期转正后的自我评价
2014/09/21 职场文书
民主评议政风行风活动心得体会
2014/10/29 职场文书
物业前台接待岗位职责
2015/04/03 职场文书
有关骆驼祥子的读书笔记
2015/06/26 职场文书