神经网络(BP)算法Python实现及应用


Posted in Python onApril 16, 2018

本文实例为大家分享了Python实现神经网络算法及应用的具体代码,供大家参考,具体内容如下

首先用Python实现简单地神经网络算法:

import numpy as np


# 定义tanh函数
def tanh(x):
  return np.tanh(x)


# tanh函数的导数
def tan_deriv(x):
  return 1.0 - np.tanh(x) * np.tan(x)


# sigmoid函数
def logistic(x):
  return 1 / (1 + np.exp(-x))


# sigmoid函数的导数
def logistic_derivative(x):
  return logistic(x) * (1 - logistic(x))


class NeuralNetwork:
  def __init__(self, layers, activation='tanh'):
    """
    神经网络算法构造函数
    :param layers: 神经元层数
    :param activation: 使用的函数(默认tanh函数)
    :return:none
    """
    if activation == 'logistic':
      self.activation = logistic
      self.activation_deriv = logistic_derivative
    elif activation == 'tanh':
      self.activation = tanh
      self.activation_deriv = tan_deriv

    # 权重列表
    self.weights = []
    # 初始化权重(随机)
    for i in range(1, len(layers) - 1):
      self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
      self.weights.append((2 * np.random.random((layers[i] + 1, layers[i + 1])) - 1) * 0.25)

  def fit(self, X, y, learning_rate=0.2, epochs=10000):
    """
    训练神经网络
    :param X: 数据集(通常是二维)
    :param y: 分类标记
    :param learning_rate: 学习率(默认0.2)
    :param epochs: 训练次数(最大循环次数,默认10000)
    :return: none
    """
    # 确保数据集是二维的
    X = np.atleast_2d(X)

    temp = np.ones([X.shape[0], X.shape[1] + 1])
    temp[:, 0: -1] = X
    X = temp
    y = np.array(y)

    for k in range(epochs):
      # 随机抽取X的一行
      i = np.random.randint(X.shape[0])
      # 用随机抽取的这一组数据对神经网络更新
      a = [X[i]]
      # 正向更新
      for l in range(len(self.weights)):
        a.append(self.activation(np.dot(a[l], self.weights[l])))
      error = y[i] - a[-1]
      deltas = [error * self.activation_deriv(a[-1])]

      # 反向更新
      for l in range(len(a) - 2, 0, -1):
        deltas.append(deltas[-1].dot(self.weights[l].T) * self.activation_deriv(a[l]))
        deltas.reverse()
      for i in range(len(self.weights)):
        layer = np.atleast_2d(a[i])
        delta = np.atleast_2d(deltas[i])
        self.weights[i] += learning_rate * layer.T.dot(delta)

  def predict(self, x):
    x = np.array(x)
    temp = np.ones(x.shape[0] + 1)
    temp[0:-1] = x
    a = temp
    for l in range(0, len(self.weights)):
      a = self.activation(np.dot(a, self.weights[l]))
    return a

使用自己定义的神经网络算法实现一些简单的功能:

 小案例:

X:                  Y
0 0                 0
0 1                 1
1 0                 1
1 1                 0

from NN.NeuralNetwork import NeuralNetwork
import numpy as np

nn = NeuralNetwork([2, 2, 1], 'tanh')
temp = [[0, 0], [0, 1], [1, 0], [1, 1]]
X = np.array(temp)
y = np.array([0, 1, 1, 0])
nn.fit(X, y)
for i in temp:
  print(i, nn.predict(i))

神经网络(BP)算法Python实现及应用

发现结果基本机制,无限接近0或者无限接近1 

第二个例子:识别图片中的数字

导入数据:

from sklearn.datasets import load_digits
import pylab as pl

digits = load_digits()
print(digits.data.shape)
pl.gray()
pl.matshow(digits.images[0])
pl.show()

观察下:大小:(1797, 64)

数字0

神经网络(BP)算法Python实现及应用

接下来的代码是识别它们:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
from NN.NeuralNetwork import NeuralNetwork
from sklearn.cross_validation import train_test_split

# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target
# 处理数据,使得数据处于0,1之间,满足神经网络算法的要求
X -= X.min()
X /= X.max()

# 层数:
# 输出层10个数字
# 输入层64因为图片是8*8的,64像素
# 隐藏层假设100
nn = NeuralNetwork([64, 100, 10], 'logistic')
# 分隔训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y)

# 转化成sklearn需要的二维数据类型
labels_train = LabelBinarizer().fit_transform(y_train)
labels_test = LabelBinarizer().fit_transform(y_test)
print("start fitting")
# 训练3000次
nn.fit(X_train, labels_train, epochs=3000)
predictions = []
for i in range(X_test.shape[0]):
  o = nn.predict(X_test[i])
  # np.argmax:第几个数对应最大概率值
  predictions.append(np.argmax(o))

# 打印预测相关信息
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

结果:

矩阵对角线代表预测正确的数量,发现正确率很多

神经网络(BP)算法Python实现及应用

这张表更直观地显示出预测正确率:

共450个案例,成功率94%

神经网络(BP)算法Python实现及应用

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 控制语句
Nov 03 Python
为python设置socket代理的方法
Jan 14 Python
Tensorflow 利用tf.contrib.learn建立输入函数的方法
Feb 08 Python
PyQt5 QSerialPort子线程操作的实现
Apr 21 Python
详解关于Django中ORM数据库迁移的配置
Oct 08 Python
详解Python Matplotlib解决绘图X轴值不按数组排序问题
Aug 05 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
python使用建议与技巧分享(一)
Aug 17 Python
Python 通过爬虫实现GitHub网页的模拟登录的示例代码
Aug 17 Python
python定时截屏实现
Nov 02 Python
Elasticsearch 批量操作
Apr 19 Python
Python实现Matplotlib,Seaborn动态数据图
May 06 Python
python读取视频流提取视频帧的两种方法
Oct 22 #Python
python读取和保存视频文件
Apr 16 #Python
Python读取视频的两种方法(imageio和cv2)
Apr 15 #Python
python2.7实现FTP文件下载功能
Apr 15 #Python
python实现多线程网页下载器
Apr 15 #Python
Python实现定时精度可调节的定时器
Apr 15 #Python
Python编写一个优美的下载器
Apr 15 #Python
You might like
PHP令牌 Token改进版
2008/07/18 PHP
PHP英文字母大小写转换函数小结
2014/05/03 PHP
大家在抢红包,程序员在研究红包算法
2015/08/31 PHP
PHP开发之归档格式phar文件概念与用法详解【创建,使用,解包还原提取】
2017/11/17 PHP
php+laravel依赖注入知识点总结
2019/11/04 PHP
php反序列化长度变化尾部字符串逃逸(0CTF-2016-piapiapia)
2020/02/15 PHP
通过MSXML2自动获取QQ个人头像及在线情况(给初学者)
2007/01/22 Javascript
return false;和e.preventDefault();的区别
2010/07/11 Javascript
jquery ajax 同步异步的执行 return值不能取得的解决方案
2012/01/08 Javascript
jquerymobile checkbox及时刷新才能获取其准确值
2012/04/14 Javascript
JavaScript 参数中的数组展开 [译]
2012/09/21 Javascript
JavaScript 创建运动框架的实现代码
2013/05/08 Javascript
node.js中的console用法总结
2014/12/15 Javascript
Bootstrap3学习笔记(二)之排版
2016/05/20 Javascript
angularjs 源码解析之injector
2016/08/22 Javascript
Bootstrap Search Suggest使用例子
2016/12/21 Javascript
jQuery插件autocomplete使用详解
2017/02/04 Javascript
jQuery实现弹窗居中效果类似alert()
2017/02/27 Javascript
基于JavaScript实现图片连播和联级菜单实例代码
2017/07/28 Javascript
node.js中使用Export和Import的方法
2017/09/18 Javascript
Vue-router结合transition实现app前进后退动画切换效果的实例
2017/10/11 Javascript
从源码看angular/material2 中 dialog模块的实现方法
2017/10/18 Javascript
详解wepy开发小程序踩过的坑(小结)
2019/05/22 Javascript
express框架中使用jwt实现验证的方法
2019/08/25 Javascript
十分钟教你上手ES2020新特性
2020/02/12 Javascript
Python操作Redis之设置key的过期时间实例代码
2018/01/25 Python
2014组织生活会方案
2014/05/19 职场文书
2014年冬季防火方案
2014/05/21 职场文书
会计求职信范文
2014/05/24 职场文书
上课迟到检讨书范文
2015/05/06 职场文书
2015年电信员工工作总结
2015/05/26 职场文书
2015年教研员工作总结
2015/05/26 职场文书
PHP策略模式写法
2021/04/01 PHP
Pytorch 如何实现LSTM时间序列预测
2021/05/17 Python
动态规划之使用备忘录来改进Javascript函数
2022/04/07 Javascript
golang用type-switch判断interface的实际存储类型
2022/04/14 Golang