Python实现的三层BP神经网络算法示例


Posted in Python onFebruary 07, 2018

本文实例讲述了Python实现的三层BP神经网络算法。分享给大家供大家参考,具体如下:

这是一个非常漂亮的三层反向传播神经网络的python实现,下一步我准备试着将其修改为多层BP神经网络。

下面是运行演示函数的截图,你会发现预测的结果很惊人!

Python实现的三层BP神经网络算法示例

提示:运行演示函数的时候,可以尝试改变隐藏层的节点数,看节点数增加了,预测的精度会否提升

import math
import random
import string
random.seed(0)
# 生成区间[a, b)内的随机数
def rand(a, b):
 return (b-a)*random.random() + a
# 生成大小 I*J 的矩阵,默认零矩阵 (当然,亦可用 NumPy 提速)
def makeMatrix(I, J, fill=0.0):
 m = []
 for i in range(I):
  m.append([fill]*J)
 return m
# 函数 sigmoid,这里采用 tanh,因为看起来要比标准的 1/(1+e^-x) 漂亮些
def sigmoid(x):
 return math.tanh(x)
# 函数 sigmoid 的派生函数, 为了得到输出 (即:y)
def dsigmoid(y):
 return 1.0 - y**2
class NN:
 ''' 三层反向传播神经网络 '''
 def __init__(self, ni, nh, no):
  # 输入层、隐藏层、输出层的节点(数)
  self.ni = ni + 1 # 增加一个偏差节点
  self.nh = nh
  self.no = no
  # 激活神经网络的所有节点(向量)
  self.ai = [1.0]*self.ni
  self.ah = [1.0]*self.nh
  self.ao = [1.0]*self.no
  # 建立权重(矩阵)
  self.wi = makeMatrix(self.ni, self.nh)
  self.wo = makeMatrix(self.nh, self.no)
  # 设为随机值
  for i in range(self.ni):
   for j in range(self.nh):
    self.wi[i][j] = rand(-0.2, 0.2)
  for j in range(self.nh):
   for k in range(self.no):
    self.wo[j][k] = rand(-2.0, 2.0)
  # 最后建立动量因子(矩阵)
  self.ci = makeMatrix(self.ni, self.nh)
  self.co = makeMatrix(self.nh, self.no)
 def update(self, inputs):
  if len(inputs) != self.ni-1:
   raise ValueError('与输入层节点数不符!')
  # 激活输入层
  for i in range(self.ni-1):
   #self.ai[i] = sigmoid(inputs[i])
   self.ai[i] = inputs[i]
  # 激活隐藏层
  for j in range(self.nh):
   sum = 0.0
   for i in range(self.ni):
    sum = sum + self.ai[i] * self.wi[i][j]
   self.ah[j] = sigmoid(sum)
  # 激活输出层
  for k in range(self.no):
   sum = 0.0
   for j in range(self.nh):
    sum = sum + self.ah[j] * self.wo[j][k]
   self.ao[k] = sigmoid(sum)
  return self.ao[:]
 def backPropagate(self, targets, N, M):
  ''' 反向传播 '''
  if len(targets) != self.no:
   raise ValueError('与输出层节点数不符!')
  # 计算输出层的误差
  output_deltas = [0.0] * self.no
  for k in range(self.no):
   error = targets[k]-self.ao[k]
   output_deltas[k] = dsigmoid(self.ao[k]) * error
  # 计算隐藏层的误差
  hidden_deltas = [0.0] * self.nh
  for j in range(self.nh):
   error = 0.0
   for k in range(self.no):
    error = error + output_deltas[k]*self.wo[j][k]
   hidden_deltas[j] = dsigmoid(self.ah[j]) * error
  # 更新输出层权重
  for j in range(self.nh):
   for k in range(self.no):
    change = output_deltas[k]*self.ah[j]
    self.wo[j][k] = self.wo[j][k] + N*change + M*self.co[j][k]
    self.co[j][k] = change
    #print(N*change, M*self.co[j][k])
  # 更新输入层权重
  for i in range(self.ni):
   for j in range(self.nh):
    change = hidden_deltas[j]*self.ai[i]
    self.wi[i][j] = self.wi[i][j] + N*change + M*self.ci[i][j]
    self.ci[i][j] = change
  # 计算误差
  error = 0.0
  for k in range(len(targets)):
   error = error + 0.5*(targets[k]-self.ao[k])**2
  return error
 def test(self, patterns):
  for p in patterns:
   print(p[0], '->', self.update(p[0]))
 def weights(self):
  print('输入层权重:')
  for i in range(self.ni):
   print(self.wi[i])
  print()
  print('输出层权重:')
  for j in range(self.nh):
   print(self.wo[j])
 def train(self, patterns, iterations=1000, N=0.5, M=0.1):
  # N: 学习速率(learning rate)
  # M: 动量因子(momentum factor)
  for i in range(iterations):
   error = 0.0
   for p in patterns:
    inputs = p[0]
    targets = p[1]
    self.update(inputs)
    error = error + self.backPropagate(targets, N, M)
   if i % 100 == 0:
    print('误差 %-.5f' % error)
def demo():
 # 一个演示:教神经网络学习逻辑异或(XOR)------------可以换成你自己的数据试试
 pat = [
  [[0,0], [0]],
  [[0,1], [1]],
  [[1,0], [1]],
  [[1,1], [0]]
 ]
 # 创建一个神经网络:输入层有两个节点、隐藏层有两个节点、输出层有一个节点
 n = NN(2, 2, 1)
 # 用一些模式训练它
 n.train(pat)
 # 测试训练的成果(不要吃惊哦)
 n.test(pat)
 # 看看训练好的权重(当然可以考虑把训练好的权重持久化)
 #n.weights()
if __name__ == '__main__':
 demo()

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
利用Python为iOS10生成图标和截屏
Sep 24 Python
浅析Python中yield关键词的作用与用法
Nov 29 Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 Python
Laravel框架表单验证格式化输出的方法
Sep 25 Python
Python动态声明变量赋值代码实例
Dec 30 Python
新版Pycharm中Matplotlib不会弹出独立的显示窗口的问题
Jun 02 Python
浅谈TensorFlow中读取图像数据的三种方式
Jun 30 Python
python list的index()和find()的实现
Nov 16 Python
python利用opencv实现颜色检测
Feb 23 Python
10个顶级Python实用库推荐
Mar 04 Python
Python的flask接收前台的ajax的post数据和get数据的方法
Apr 12 Python
python垃圾回收机制原理分析
Apr 13 Python
Python 12306抢火车票脚本
Feb 07 #Python
django限制匿名用户访问及重定向的方法实例
Feb 07 #Python
Python用 KNN 进行验证码识别的实现方法
Feb 06 #Python
Python实现的径向基(RBF)神经网络示例
Feb 06 #Python
python实现淘宝秒杀聚划算抢购自动提醒源码
Jun 23 #Python
初探TensorFLow从文件读取图片的四种方式
Feb 06 #Python
用十张图详解TensorFlow数据读取机制(附代码)
Feb 06 #Python
You might like
ThinkPHP自动转义存储富文本编辑器内容导致读取出错的解决方法
2014/08/08 PHP
深入理解PHP中的global
2014/08/19 PHP
php解决DOM乱码的方法示例代码
2016/11/20 PHP
THINKPHP截取中文字符串函数实例代码
2017/03/20 PHP
javascript 写类方式之八
2009/07/05 Javascript
textarea不能通过maxlength属性来限制字数的解决方法
2014/09/01 Javascript
jquery实现在光标位置插入内容的方法
2015/02/05 Javascript
javascript父子页面通讯实例详解
2015/07/17 Javascript
Angular.js与Bootstrap相结合实现手风琴菜单代码
2016/04/13 Javascript
基于Node.js的JavaScript项目构建工具gulp的使用教程
2016/05/20 Javascript
js事件冒泡、事件捕获和阻止默认事件详解
2016/08/04 Javascript
AngularJs Modules详解及示例代码
2016/09/01 Javascript
vue-cli脚手架build目录下utils.js工具配置文件详解
2018/09/14 Javascript
node app 打包工具pkg的具体使用
2019/01/17 Javascript
vue下载excel的实现代码后台用post方法
2019/05/10 Javascript
vue+eslint+vscode配置教程
2019/08/09 Javascript
js实现橱窗展示效果
2020/01/11 Javascript
Vue使用JSEncrypt实现rsa加密及挂载方法
2020/02/07 Javascript
Ant Design Vue table中列超长显示...并加提示语的实例
2020/10/31 Javascript
Python使用matplotlib实现在坐标系中画一个矩形的方法
2015/05/20 Python
Python数据结构之顺序表的实现代码示例
2017/11/15 Python
python编程实现12306的一个小爬虫实例
2017/12/27 Python
Python使用itertools模块实现排列组合功能示例
2018/07/02 Python
深入浅析Python获取对象信息的函数type()、isinstance()、dir()
2018/09/17 Python
对python文件读写的缓冲行为详解
2019/02/13 Python
python读出当前时间精度到秒的代码
2019/07/05 Python
django rest framework vue 实现用户登录详解
2019/07/29 Python
python快速编写单行注释多行注释的方法
2019/07/31 Python
Python selenium环境搭建实现过程解析
2020/09/08 Python
使用Python爬虫爬取小红书完完整整的全过程
2021/01/19 Python
Stuarts London美国/加拿大:世界领先的独立男装零售商之一
2019/03/18 全球购物
我的梦想演讲稿500字
2014/08/21 职场文书
先进教师事迹材料
2014/12/16 职场文书
2015暑期工社会实践报告
2015/07/13 职场文书
CSS3 制作的书本翻页特效
2021/04/13 HTML / CSS
使用logback实现按自己的需求打印日志到自定义的文件里
2021/08/30 Java/Android