Python实现的人工神经网络算法示例【基于反向传播算法】


Posted in Python onNovember 11, 2017

本文实例讲述了Python实现的人工神经网络算法。分享给大家供大家参考,具体如下:

注意:本程序使用Python3编写,额外需要安装numpy工具包用于矩阵运算,未测试python2是否可以运行。

本程序实现了《机器学习》书中所述的反向传播算法训练人工神经网络,理论部分请参考我的读书笔记。

在本程序中,目标函数是由一个输入x和两个输出y组成,
x是在范围【-3.14, 3.14】之间随机生成的实数,而两个y值分别对应 y1 = sin(x),y2 = 1。

随机生成一万份训练样例,经过网络的学习训练后,再用随机生成的五份测试数据验证训练结果。

调节算法的学习速率,以及隐藏层个数、隐藏层大小,训练新的网络,可以观察到参数对于学习结果的影响。

算法代码如下:

#!usr/bin/env python3
# -*- coding:utf-8 -*-
import numpy as np
import math
# definition of sigmoid funtion
# numpy.exp work for arrays.
def sigmoid(x):
  return 1 / (1 + np.exp(-x))
# definition of sigmoid derivative funtion
# input must be sigmoid function's result
def sigmoid_output_to_derivative(result):
  return result*(1-result)
# init training set
def getTrainingSet(nameOfSet):
  setDict = {
    "sin": getSinSet(),
    }
  return setDict[nameOfSet]
def getSinSet():
  x = 6.2 * np.random.rand(1) - 3.14
  x = x.reshape(1,1)
  # y = np.array([5 *x]).reshape(1,1)
  # y = np.array([math.sin(x)]).reshape(1,1)
  y = np.array([math.sin(x),1]).reshape(1,2)
  return x, y
def getW(synapse, delta):
  resultList = []
  # 遍历隐藏层每个隐藏单元对每个输出的权值,比如8个隐藏单元,每个隐藏单元对两个输出各有2个权值
  for i in range(synapse.shape[0]):
    resultList.append(
      (synapse[i,:] * delta).sum()
      )
  resultArr = np.array(resultList).reshape(1, synapse.shape[0])
  return resultArr
def getT(delta, layer):
  result = np.dot(layer.T, delta)
  return result
def backPropagation(trainingExamples, etah, input_dim, output_dim, hidden_dim, hidden_num):
  # 可行条件
  if hidden_num < 1:
    print("隐藏层数不得小于1")
    return
  # 初始化网络权重矩阵,这个是核心
  synapseList = []
  # 输入层与隐含层1
  synapseList.append(2*np.random.random((input_dim,hidden_dim)) - 1)
  # 隐含层1与隐含层2, 2->3,,,,,,n-1->n
  for i in range(hidden_num-1):
    synapseList.append(2*np.random.random((hidden_dim,hidden_dim)) - 1)
  # 隐含层n与输出层
  synapseList.append(2*np.random.random((hidden_dim,output_dim)) - 1)
  iCount = 0
  lastErrorMax = 99999
  # while True:
  for i in range(10000):
    errorMax = 0
    for x, y in trainingExamples:
      iCount += 1
      layerList = []
      # 正向传播
      layerList.append(
        sigmoid(np.dot(x,synapseList[0]))
        )
      for j in range(hidden_num):
        layerList.append(
          sigmoid(np.dot(layerList[-1],synapseList[j+1]))
          )
      # 对于网络中的每个输出单元k,计算它的误差项
      deltaList = []
      layerOutputError = y - layerList[-1]
      # 收敛条件
      errorMax = layerOutputError.sum() if layerOutputError.sum() > errorMax else errorMax
      deltaK = sigmoid_output_to_derivative(layerList[-1]) * layerOutputError
      deltaList.append(deltaK)
      iLength = len(synapseList)
      for j in range(hidden_num):
        w = getW(synapseList[iLength - 1 - j], deltaList[j])
        delta = sigmoid_output_to_derivative(layerList[iLength - 2 - j]) * w
        deltaList.append(delta)
      # 更新每个网络权值w(ji)
      for j in range(len(synapseList)-1, 0, -1):
        t = getT(deltaList[iLength - 1 -j], layerList[j-1])
        synapseList[j] = synapseList[j] + etah * t
      t = getT(deltaList[-1], x)
      synapseList[0] = synapseList[0] + etah * t
    print("最大输出误差:")
    print(errorMax)
    if abs(lastErrorMax - errorMax) < 0.0001:
      print("收敛了")
      print("####################")
      break
    lastErrorMax = errorMax
  # 测试训练好的网络
  for i in range(5):
    xTest, yReal = getSinSet()
    layerTmp = sigmoid(np.dot(xTest,synapseList[0]))
    for j in range(1, len(synapseList), 1):
      layerTmp = sigmoid(np.dot(layerTmp,synapseList[j]))
    yTest = layerTmp
    print("x:")
    print(xTest)
    print("实际的y:")
    print(yReal)
    print("神经元网络输出的y:")
    print(yTest)
    print("最终输出误差:")
    print(np.abs(yReal - yTest))
    print("#####################")
  print("迭代次数:")
  print(iCount)
if __name__ == '__main__':
  import datetime
  tStart = datetime.datetime.now()
  # 使用什么样的训练样例
  nameOfSet = "sin"
  x, y = getTrainingSet(nameOfSet)
  # setting of parameters
  # 这里设置了学习速率。
  etah = 0.01
  # 隐藏层数
  hidden_num = 2
  # 网络输入层的大小
  input_dim = x.shape[1]
  # 隐含层的大小
  hidden_dim = 100
  # 输出层的大小
  output_dim = y.shape[1]
  # 构建训练样例
  trainingExamples = []
  for i in range(10000):
    x, y = getTrainingSet(nameOfSet)
    trainingExamples.append((x, y))
  # 开始用反向传播算法训练网络
  backPropagation(trainingExamples, etah, input_dim, output_dim, hidden_dim, hidden_num)
  tEnd = datetime.datetime.now()
  print("time cost:")
  print(tEnd - tStart)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python实现备份文件实例
Sep 16 Python
Java Web开发过程中登陆模块的验证码的实现方式总结
May 25 Python
Python第三方库的安装方法总结
Jun 06 Python
Python入门_浅谈数据结构的4种基本类型
May 16 Python
Python实现的knn算法示例
Jun 14 Python
Python中几种属性访问的区别与用法详解
Oct 10 Python
python实现朴素贝叶斯算法
Nov 19 Python
python实现邮件发送功能
Aug 10 Python
opencv转换颜色空间更改图片背景
Aug 20 Python
Python 中pandas索引切片读取数据缺失数据处理问题
Oct 09 Python
python绘制分布折线图的示例
Sep 24 Python
Django给表单添加honeypot验证增加安全性
May 06 Python
python中使用正则表达式的后向搜索肯定模式(推荐)
Nov 11 #Python
python基础练习之几个简单的游戏
Nov 10 #Python
Python实现购物车功能的方法分析
Nov 10 #Python
Python实现的单向循环链表功能示例
Nov 10 #Python
Python3中的列表,元组,字典,字符串相关知识小结
Nov 10 #Python
浅谈Python处理PDF的方法
Nov 10 #Python
django开发教程之利用缓存文件进行页面缓存的方法
Nov 10 #Python
You might like
ZF等常用php框架中存在的问题
2008/01/10 PHP
PHP_NETWORK_GETADDRESSES: GETADDRINFO FAILED问题解决办法
2014/05/04 PHP
Laravel中Trait的用法实例详解
2016/03/16 PHP
php使用Header函数,PHP_AUTH_PW和PHP_AUTH_USER做用户验证
2016/05/04 PHP
PHP实现数据库的增删查改功能及完整代码
2018/04/18 PHP
PHP基于DateTime类解决Unix时间戳与日期互转问题【针对1970年前及2038年后时间戳】
2018/06/13 PHP
基于laravel Request的所有方法详解
2019/09/29 PHP
网页禁用右键实现代码(JavaScript代码)
2009/10/29 Javascript
setTimeout与setInterval在不同浏览器下的差异
2010/01/24 Javascript
javascript 基础篇4 window对象,DOM
2012/03/14 Javascript
JavaScript 高级篇之DOM文档,简单封装及调用、动态添加、删除样式(六)
2012/04/07 Javascript
jQuery ReferenceError: $ is not defined 错误的处理办法
2013/05/10 Javascript
Javascript中Event属性搜集整理
2013/09/17 Javascript
jquery实现多行文字图片滚动效果示例代码
2014/10/10 Javascript
jQuery对象的链式操作用法分析
2016/05/10 Javascript
javascript之IE版本检测超简单方法
2016/08/20 Javascript
在vue中,v-for的索引index在html中的使用方法
2018/03/06 Javascript
解决Antd Table表头加Icon和气泡提示的坑
2020/11/17 Javascript
node使用async_hooks模块进行请求追踪
2021/01/28 Javascript
[45:44]完美世界DOTA2联赛PWL S2 FTD vs PXG 第一场 11.27
2020/12/01 DOTA
python编程实现希尔排序
2017/04/13 Python
pyspark 读取csv文件创建DataFrame的两种方法
2018/06/07 Python
Python3数据库操作包pymysql的操作方法
2018/07/16 Python
详解基于python-django框架的支付宝支付案例
2019/09/23 Python
Python 中pandas索引切片读取数据缺失数据处理问题
2019/10/09 Python
使用IDLE的Python shell窗口实例详解
2019/11/19 Python
python 线性回归分析模型检验标准--拟合优度详解
2020/02/24 Python
英国第一蛋白粉品牌:Myprotein
2016/09/14 全球购物
德国PC硬件网站:CASEKING
2016/10/20 全球购物
新奥尔良珠宝:Mignon Faget
2020/11/23 全球购物
物流专业毕业生推荐信范文
2013/11/18 职场文书
环保倡议书
2014/04/14 职场文书
自我推荐信范文
2014/05/09 职场文书
公民代理授权委托书
2014/09/24 职场文书
地道战观后感300字
2015/06/04 职场文书
MySQL命令行操作时的编码问题详解
2021/04/14 MySQL