基于python的BP神经网络及异或实现过程解析


Posted in Python onSeptember 30, 2019

BP神经网络是最简单的神经网络模型了,三层能够模拟非线性函数效果。

基于python的BP神经网络及异或实现过程解析

难点:

  • 如何确定初始化参数?
  • 如何确定隐含层节点数量?
  • 迭代多少次?如何更快收敛?
  • 如何获得全局最优解?
'''
neural networks 

created on 2019.9.24
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt

'''
neural network 
'''
class NeuralNetwork:

 def __init__(self, layer_nums, iter_num = 10000, batch_size = 1):
  self.__ILI = 0;
  self.__HLI = 1;
  self.__OLI = 2;
  self.__TLN = 3;

  if len(layer_nums) != self.__TLN:
   raise Exception("layer_nums length must be 3");

  self.__layer_nums = layer_nums; #array [layer0_num, layer1_num ...layerN_num]
  self.__iter_num = iter_num;
  self.__batch_size = batch_size;
 
 def train(self, X, Y):
  X = numpy.array(X);
  Y = numpy.array(Y);

  self.L = [];
  #initialize parameters
  self.__weight = [];
  self.__bias = [];
  self.__step_len = [];
  for layer_index in range(1, self.__TLN):
   self.__weight.append(numpy.random.rand(self.__layer_nums[layer_index - 1], self.__layer_nums[layer_index]) * 2 - 1.0);
   self.__bias.append(numpy.random.rand(self.__layer_nums[layer_index]) * 2 - 1.0);
   self.__step_len.append(0.3);

  logging.info("bias:%s" % (self.__bias));
  logging.info("weight:%s" % (self.__weight));

  for iter_index in range(self.__iter_num):
   sample_index = random.randint(0, len(X) - 1);
   logging.debug("-----round:%s, select sample %s-----" % (iter_index, sample_index));
   output = self.forward_pass(X[sample_index]);
   g = (-output[2] + Y[sample_index]) * self.activation_drive(output[2]);
   logging.debug("g:%s" % (g));
   for j in range(len(output[1])):
    self.__weight[1][j] += self.__step_len[1] * g * output[1][j];
   self.__bias[1] -= self.__step_len[1] * g;

   e = [];
   for i in range(self.__layer_nums[self.__HLI]):
    e.append(numpy.dot(g, self.__weight[1][i]) * self.activation_drive(output[1][i]));
   e = numpy.array(e);
   logging.debug("e:%s" % (e));
   for j in range(len(output[0])):
    self.__weight[0][j] += self.__step_len[0] * e * output[0][j];
   self.__bias[0] -= self.__step_len[0] * e;

   l = 0;
   for i in range(len(X)):
    predictions = self.forward_pass(X[i])[2];
    l += 0.5 * numpy.sum((predictions - Y[i]) ** 2);
   l /= len(X);
   self.L.append(l);

   logging.debug("bias:%s" % (self.__bias));
   logging.debug("weight:%s" % (self.__weight));
   logging.debug("loss:%s" % (l));
  logging.info("bias:%s" % (self.__bias));
  logging.info("weight:%s" % (self.__weight));
  logging.info("L:%s" % (self.L));
 
 def activation(self, z):
  return (1.0 / (1.0 + numpy.exp(-z)));

 def activation_drive(self, y):
  return y * (1.0 - y);

 def forward_pass(self, x):
  data = numpy.copy(x);
  result = [];
  result.append(data);
  for layer_index in range(self.__TLN - 1):
   data = self.activation(numpy.dot(data, self.__weight[layer_index]) - self.__bias[layer_index]);
   result.append(data);
  return numpy.array(result);

 def predict(self, x):
  return self.forward_pass(x)[self.__OLI];


def main():
 logging.basicConfig(level = logging.INFO,
   format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
   datefmt = '%a, %d %b %Y %H:%M:%S');
   
 logging.info("trainning begin.");
 nn = NeuralNetwork([2, 2, 1]);
 X = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]]);
 Y = numpy.array([0, 1, 0, 1]);
 nn.train(X, Y);

 logging.info("trainning end. predict begin.");
 for x in X:
  print(x, nn.predict(x));

 plt.plot(nn.L)
 plt.show();

if __name__ == "__main__":
 main();

具体收敛效果

基于python的BP神经网络及异或实现过程解析

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多重继承实例
Oct 11 Python
python使用smtplib模块通过gmail实现邮件发送的方法
May 08 Python
Python中的defaultdict与__missing__()使用介绍
Feb 03 Python
Python基于辗转相除法求解最大公约数的方法示例
Apr 04 Python
python一行sql太长折成多行并且有多个参数的方法
Jul 19 Python
使用Python创建简单的HTTP服务器的方法步骤
Apr 26 Python
python实现大量图片重命名
Mar 23 Python
通过python扫描二维码/条形码并打印数据
Nov 14 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
python语言实现贪吃蛇游戏
Nov 13 Python
用Python制作音乐海报
Jan 26 Python
Window10下python3.7 安装与卸载教程图解
Sep 30 #Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 #Python
Python3 合并二叉树的实现
Sep 30 #Python
自适应线性神经网络Adaline的python实现详解
Sep 30 #Python
softmax及python实现过程解析
Sep 30 #Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
You might like
jquery+thinkphp实现跨域抓取数据的方法
2016/10/15 PHP
两个DIV等高的JS的实现代码
2007/12/23 Javascript
JS面向对象编程 for Cookie
2010/09/19 Javascript
JQuery文本框高亮显示插件代码
2011/04/02 Javascript
jquery实现每个数字上都带进度条的幻灯片
2013/02/20 Javascript
页面加载完成后再执行JS的jquery写法以及区别说明
2014/02/22 Javascript
javascript查询字符串参数的方法
2015/01/28 Javascript
快速学习AngularJs HTTP响应拦截器
2015/12/31 Javascript
jquery datatable服务端分页
2016/08/31 Javascript
vue中渐进过渡效果实现
2016/10/27 Javascript
详解vue-cli快速构建项目以及引入bootstrap、jq
2017/05/26 Javascript
详解VUE中v-bind的基本用法
2017/07/13 Javascript
Bootstrap实现下拉菜单多级联动
2017/11/23 Javascript
浅谈Node.js 子进程与应用场景
2018/01/24 Javascript
在vue项目中使用Nprogress.js进度条的方法
2018/01/31 Javascript
使用vue2实现购物车和地址选配功能
2018/03/29 Javascript
解决vue中监听input只能输入数字及英文或者其他情况的问题
2018/08/30 Javascript
vue中的router-view组件的使用教程
2018/10/23 Javascript
JS函数参数的传递与同名参数实例分析
2020/03/16 Javascript
Vue 基于 vuedraggable 实现选中、拖拽、排序效果
2020/05/18 Javascript
python读取Android permission文件
2013/11/01 Python
部署Python的框架下的web app的详细教程
2015/04/30 Python
在Python中操作日期和时间之gmtime()方法的使用
2015/05/22 Python
python简单实现旋转图片的方法
2015/05/30 Python
win10下tensorflow和matplotlib安装教程
2018/09/19 Python
一行Python代码过滤标点符号等特殊字符
2019/08/12 Python
在Python中用GDAL实现矢量对栅格的切割实例
2020/03/11 Python
LTD Commodities:礼品,独特发现,家居装饰,家用器皿
2017/08/11 全球购物
仓库门卫岗位职责
2013/12/22 职场文书
养殖人员的创业计划书范文
2013/12/26 职场文书
财务学生的职业生涯发展
2014/02/11 职场文书
团队经理竞聘书
2014/03/31 职场文书
2014幼儿园卫生保健工作总结
2014/12/05 职场文书
2015年乡镇卫生院工作总结
2015/04/22 职场文书
一次项目中Thinkphp绕过禁用函数的实战记录
2021/11/17 PHP
java后台调用接口及处理跨域问题的解决
2022/03/24 Java/Android