Python实现的人工神经网络算法示例【基于反向传播算法】


Posted in Python onNovember 11, 2017

本文实例讲述了Python实现的人工神经网络算法。分享给大家供大家参考,具体如下:

注意:本程序使用Python3编写,额外需要安装numpy工具包用于矩阵运算,未测试python2是否可以运行。

本程序实现了《机器学习》书中所述的反向传播算法训练人工神经网络,理论部分请参考我的读书笔记。

在本程序中,目标函数是由一个输入x和两个输出y组成,
x是在范围【-3.14, 3.14】之间随机生成的实数,而两个y值分别对应 y1 = sin(x),y2 = 1。

随机生成一万份训练样例,经过网络的学习训练后,再用随机生成的五份测试数据验证训练结果。

调节算法的学习速率,以及隐藏层个数、隐藏层大小,训练新的网络,可以观察到参数对于学习结果的影响。

算法代码如下:

#!usr/bin/env python3
# -*- coding:utf-8 -*-
import numpy as np
import math
# definition of sigmoid funtion
# numpy.exp work for arrays.
def sigmoid(x):
  return 1 / (1 + np.exp(-x))
# definition of sigmoid derivative funtion
# input must be sigmoid function's result
def sigmoid_output_to_derivative(result):
  return result*(1-result)
# init training set
def getTrainingSet(nameOfSet):
  setDict = {
    "sin": getSinSet(),
    }
  return setDict[nameOfSet]
def getSinSet():
  x = 6.2 * np.random.rand(1) - 3.14
  x = x.reshape(1,1)
  # y = np.array([5 *x]).reshape(1,1)
  # y = np.array([math.sin(x)]).reshape(1,1)
  y = np.array([math.sin(x),1]).reshape(1,2)
  return x, y
def getW(synapse, delta):
  resultList = []
  # 遍历隐藏层每个隐藏单元对每个输出的权值,比如8个隐藏单元,每个隐藏单元对两个输出各有2个权值
  for i in range(synapse.shape[0]):
    resultList.append(
      (synapse[i,:] * delta).sum()
      )
  resultArr = np.array(resultList).reshape(1, synapse.shape[0])
  return resultArr
def getT(delta, layer):
  result = np.dot(layer.T, delta)
  return result
def backPropagation(trainingExamples, etah, input_dim, output_dim, hidden_dim, hidden_num):
  # 可行条件
  if hidden_num < 1:
    print("隐藏层数不得小于1")
    return
  # 初始化网络权重矩阵,这个是核心
  synapseList = []
  # 输入层与隐含层1
  synapseList.append(2*np.random.random((input_dim,hidden_dim)) - 1)
  # 隐含层1与隐含层2, 2->3,,,,,,n-1->n
  for i in range(hidden_num-1):
    synapseList.append(2*np.random.random((hidden_dim,hidden_dim)) - 1)
  # 隐含层n与输出层
  synapseList.append(2*np.random.random((hidden_dim,output_dim)) - 1)
  iCount = 0
  lastErrorMax = 99999
  # while True:
  for i in range(10000):
    errorMax = 0
    for x, y in trainingExamples:
      iCount += 1
      layerList = []
      # 正向传播
      layerList.append(
        sigmoid(np.dot(x,synapseList[0]))
        )
      for j in range(hidden_num):
        layerList.append(
          sigmoid(np.dot(layerList[-1],synapseList[j+1]))
          )
      # 对于网络中的每个输出单元k,计算它的误差项
      deltaList = []
      layerOutputError = y - layerList[-1]
      # 收敛条件
      errorMax = layerOutputError.sum() if layerOutputError.sum() > errorMax else errorMax
      deltaK = sigmoid_output_to_derivative(layerList[-1]) * layerOutputError
      deltaList.append(deltaK)
      iLength = len(synapseList)
      for j in range(hidden_num):
        w = getW(synapseList[iLength - 1 - j], deltaList[j])
        delta = sigmoid_output_to_derivative(layerList[iLength - 2 - j]) * w
        deltaList.append(delta)
      # 更新每个网络权值w(ji)
      for j in range(len(synapseList)-1, 0, -1):
        t = getT(deltaList[iLength - 1 -j], layerList[j-1])
        synapseList[j] = synapseList[j] + etah * t
      t = getT(deltaList[-1], x)
      synapseList[0] = synapseList[0] + etah * t
    print("最大输出误差:")
    print(errorMax)
    if abs(lastErrorMax - errorMax) < 0.0001:
      print("收敛了")
      print("####################")
      break
    lastErrorMax = errorMax
  # 测试训练好的网络
  for i in range(5):
    xTest, yReal = getSinSet()
    layerTmp = sigmoid(np.dot(xTest,synapseList[0]))
    for j in range(1, len(synapseList), 1):
      layerTmp = sigmoid(np.dot(layerTmp,synapseList[j]))
    yTest = layerTmp
    print("x:")
    print(xTest)
    print("实际的y:")
    print(yReal)
    print("神经元网络输出的y:")
    print(yTest)
    print("最终输出误差:")
    print(np.abs(yReal - yTest))
    print("#####################")
  print("迭代次数:")
  print(iCount)
if __name__ == '__main__':
  import datetime
  tStart = datetime.datetime.now()
  # 使用什么样的训练样例
  nameOfSet = "sin"
  x, y = getTrainingSet(nameOfSet)
  # setting of parameters
  # 这里设置了学习速率。
  etah = 0.01
  # 隐藏层数
  hidden_num = 2
  # 网络输入层的大小
  input_dim = x.shape[1]
  # 隐含层的大小
  hidden_dim = 100
  # 输出层的大小
  output_dim = y.shape[1]
  # 构建训练样例
  trainingExamples = []
  for i in range(10000):
    x, y = getTrainingSet(nameOfSet)
    trainingExamples.append((x, y))
  # 开始用反向传播算法训练网络
  backPropagation(trainingExamples, etah, input_dim, output_dim, hidden_dim, hidden_num)
  tEnd = datetime.datetime.now()
  print("time cost:")
  print(tEnd - tStart)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python素数检测实例分析
Jun 15 Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 Python
python如何统计序列中元素
Jul 31 Python
TensorFlow实现简单卷积神经网络
May 24 Python
python用pandas数据加载、存储与文件格式的实例
Dec 07 Python
基于python的socket实现单机五子棋到双人对战
Mar 24 Python
python 解决cv2绘制中文乱码问题
Dec 23 Python
python 对任意数据和曲线进行拟合并求出函数表达式的三种解决方案
Feb 18 Python
Python关键字及可变参数*args,**kw原理解析
Apr 04 Python
在CentOS7下安装Python3教程解析
Jul 09 Python
python如何导出微信公众号文章方法详解
Aug 31 Python
Python可变集合和不可变集合的构造方法大全
Dec 06 Python
python中使用正则表达式的后向搜索肯定模式(推荐)
Nov 11 #Python
python基础练习之几个简单的游戏
Nov 10 #Python
Python实现购物车功能的方法分析
Nov 10 #Python
Python实现的单向循环链表功能示例
Nov 10 #Python
Python3中的列表,元组,字典,字符串相关知识小结
Nov 10 #Python
浅谈Python处理PDF的方法
Nov 10 #Python
django开发教程之利用缓存文件进行页面缓存的方法
Nov 10 #Python
You might like
php使用正则过滤js脚本代码实例
2014/05/10 PHP
php 判断网页是否是utf8编码的方法
2014/06/06 PHP
Thinkphp将二维数组变为标签适用的一维数组方法总结
2014/10/30 PHP
PHP正则表达式之捕获组与非捕获组
2015/11/06 PHP
javascript基础的动画教程,直观易懂
2007/01/10 Javascript
js 返回时间戳所对应的具体时间
2010/07/20 Javascript
解决jquery的.animate()函数在IE6下的问题
2010/12/03 Javascript
去掉gridPanel表头全选框的小例子
2013/07/18 Javascript
node.js中Socket.IO的进阶使用技巧
2014/11/04 Javascript
分享使用AngularJS创建应用的5个框架
2015/12/05 Javascript
基于javascript实现根据身份证号码识别性别和年龄
2016/01/22 Javascript
jQuery通过ajax请求php遍历json数组到table中的代码(推荐)
2016/06/12 Javascript
再谈Javascript中的基本类型和引用类型(推荐)
2016/07/01 Javascript
JavaScript实现简单的星星评分效果
2017/05/18 Javascript
react-redux中connect()方法详细解析
2017/05/27 Javascript
基于jquery实现多级菜单效果
2017/07/25 jQuery
vue中的scope使用详解
2017/10/29 Javascript
jQuery实现每隔一段时间自动更换样式的方法分析
2018/05/03 jQuery
微信小程序页面间传值与页面取值操作实例分析
2019/04/30 Javascript
解密Python中的描述符(descriptor)
2015/06/03 Python
python实现txt文件格式转换为arff格式
2018/05/31 Python
python3 读取Excel表格中的数据
2018/10/16 Python
django 使用全局搜索功能的实例详解
2019/07/18 Python
IronPython连接MySQL的方法步骤
2019/12/27 Python
pycharm第三方库安装失败的问题及解决经验分享
2020/05/09 Python
护理自荐信
2013/10/22 职场文书
行政经理岗位职责
2013/11/09 职场文书
2014年公司庆元旦活动方案
2014/03/05 职场文书
新闻传媒系求职信范文
2014/04/19 职场文书
工会换届选举方案
2014/05/21 职场文书
公司周年庆典标语
2014/10/07 职场文书
财务会计岗位职责
2015/02/03 职场文书
在JavaScript中如何使用宏详解
2021/05/06 Javascript
简单介绍Python的第三方库yaml
2021/06/18 Python
解决SpringBoot跨域的三种方式
2021/06/26 Java/Android
关于Oracle12C默认用户名system密码不正确的解决方案
2021/10/16 Oracle