基于python的BP神经网络及异或实现过程解析


Posted in Python onSeptember 30, 2019

BP神经网络是最简单的神经网络模型了,三层能够模拟非线性函数效果。

基于python的BP神经网络及异或实现过程解析

难点:

  • 如何确定初始化参数?
  • 如何确定隐含层节点数量?
  • 迭代多少次?如何更快收敛?
  • 如何获得全局最优解?
'''
neural networks 

created on 2019.9.24
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt

'''
neural network 
'''
class NeuralNetwork:

 def __init__(self, layer_nums, iter_num = 10000, batch_size = 1):
  self.__ILI = 0;
  self.__HLI = 1;
  self.__OLI = 2;
  self.__TLN = 3;

  if len(layer_nums) != self.__TLN:
   raise Exception("layer_nums length must be 3");

  self.__layer_nums = layer_nums; #array [layer0_num, layer1_num ...layerN_num]
  self.__iter_num = iter_num;
  self.__batch_size = batch_size;
 
 def train(self, X, Y):
  X = numpy.array(X);
  Y = numpy.array(Y);

  self.L = [];
  #initialize parameters
  self.__weight = [];
  self.__bias = [];
  self.__step_len = [];
  for layer_index in range(1, self.__TLN):
   self.__weight.append(numpy.random.rand(self.__layer_nums[layer_index - 1], self.__layer_nums[layer_index]) * 2 - 1.0);
   self.__bias.append(numpy.random.rand(self.__layer_nums[layer_index]) * 2 - 1.0);
   self.__step_len.append(0.3);

  logging.info("bias:%s" % (self.__bias));
  logging.info("weight:%s" % (self.__weight));

  for iter_index in range(self.__iter_num):
   sample_index = random.randint(0, len(X) - 1);
   logging.debug("-----round:%s, select sample %s-----" % (iter_index, sample_index));
   output = self.forward_pass(X[sample_index]);
   g = (-output[2] + Y[sample_index]) * self.activation_drive(output[2]);
   logging.debug("g:%s" % (g));
   for j in range(len(output[1])):
    self.__weight[1][j] += self.__step_len[1] * g * output[1][j];
   self.__bias[1] -= self.__step_len[1] * g;

   e = [];
   for i in range(self.__layer_nums[self.__HLI]):
    e.append(numpy.dot(g, self.__weight[1][i]) * self.activation_drive(output[1][i]));
   e = numpy.array(e);
   logging.debug("e:%s" % (e));
   for j in range(len(output[0])):
    self.__weight[0][j] += self.__step_len[0] * e * output[0][j];
   self.__bias[0] -= self.__step_len[0] * e;

   l = 0;
   for i in range(len(X)):
    predictions = self.forward_pass(X[i])[2];
    l += 0.5 * numpy.sum((predictions - Y[i]) ** 2);
   l /= len(X);
   self.L.append(l);

   logging.debug("bias:%s" % (self.__bias));
   logging.debug("weight:%s" % (self.__weight));
   logging.debug("loss:%s" % (l));
  logging.info("bias:%s" % (self.__bias));
  logging.info("weight:%s" % (self.__weight));
  logging.info("L:%s" % (self.L));
 
 def activation(self, z):
  return (1.0 / (1.0 + numpy.exp(-z)));

 def activation_drive(self, y):
  return y * (1.0 - y);

 def forward_pass(self, x):
  data = numpy.copy(x);
  result = [];
  result.append(data);
  for layer_index in range(self.__TLN - 1):
   data = self.activation(numpy.dot(data, self.__weight[layer_index]) - self.__bias[layer_index]);
   result.append(data);
  return numpy.array(result);

 def predict(self, x):
  return self.forward_pass(x)[self.__OLI];


def main():
 logging.basicConfig(level = logging.INFO,
   format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
   datefmt = '%a, %d %b %Y %H:%M:%S');
   
 logging.info("trainning begin.");
 nn = NeuralNetwork([2, 2, 1]);
 X = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]]);
 Y = numpy.array([0, 1, 0, 1]);
 nn.train(X, Y);

 logging.info("trainning end. predict begin.");
 for x in X:
  print(x, nn.predict(x));

 plt.plot(nn.L)
 plt.show();

if __name__ == "__main__":
 main();

具体收敛效果

基于python的BP神经网络及异或实现过程解析

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之总结参数的传递
Oct 10 Python
python使用smtplib模块通过gmail实现邮件发送的方法
May 08 Python
Python的randrange()方法使用教程
May 15 Python
批量获取及验证HTTP代理的Python脚本
Apr 23 Python
Python基于numpy灵活定义神经网络结构的方法
Aug 19 Python
Python使用django框架实现多人在线匿名聊天的小程序
Nov 29 Python
Python给你的头像加上圣诞帽
Jan 04 Python
对python mayavi三维绘图的实现详解
Jan 08 Python
在windows下使用python进行串口通讯的方法
Jul 02 Python
Python模块相关知识点小结
Mar 09 Python
python实现在列表中查找某个元素的下标示例
Nov 16 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 Python
Window10下python3.7 安装与卸载教程图解
Sep 30 #Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 #Python
Python3 合并二叉树的实现
Sep 30 #Python
自适应线性神经网络Adaline的python实现详解
Sep 30 #Python
softmax及python实现过程解析
Sep 30 #Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
You might like
用PHP控制用户的浏览器--ob*函数的使用说明
2007/03/16 PHP
解析php时间戳与日期的转换
2013/06/06 PHP
php网站地图生成类示例
2014/01/13 PHP
PHP中加密解密函数与DES加密解密实例
2014/10/17 PHP
PHP生成随机码的思路与方法实例探索
2019/04/11 PHP
js中精确计算加法和减法示例
2014/03/28 Javascript
jquery通过select列表选择框对表格数据进行过滤示例
2014/05/07 Javascript
jQuery中filter()方法用法实例
2015/01/06 Javascript
js判断是否按下了Shift键的方法
2015/01/27 Javascript
ECMAScript 5中的属性描述符详解
2015/03/02 Javascript
学习JavaScript设计模式(策略模式)
2015/11/26 Javascript
详解Angular的8个主要构造块
2017/06/20 Javascript
关于vue-router的那些事儿
2018/05/23 Javascript
vue select选择框数据变化监听方法
2018/08/24 Javascript
Vue 实现手动刷新组件的方法
2019/02/19 Javascript
详解vue中使用protobuf踩坑记
2019/05/07 Javascript
js根据后缀判断文件文件类型的代码
2020/05/09 Javascript
[45:50]完美世界DOTA2联赛PWL S3 CPG vs Forest 第二场 12.16
2020/12/17 DOTA
Python中的jquery PyQuery库使用小结
2014/05/13 Python
Python实现提取文章摘要的方法
2015/04/21 Python
Python实现ssh批量登录并执行命令
2016/10/25 Python
Python中turtle库的使用实例
2019/09/09 Python
在keras中model.fit_generator()和model.fit()的区别说明
2020/06/17 Python
详解Python遍历列表时删除元素的正确做法
2021/01/07 Python
京东全球售:直邮香港,澳门,台湾,美国,澳大利亚等地区
2017/09/24 全球购物
举例说明类变量和实例变量的区别
2016/06/30 面试题
小学生环保倡议书
2014/05/15 职场文书
学校运动会报道稿
2014/09/23 职场文书
世界遗产的导游词
2015/02/13 职场文书
手机销售员岗位职责
2015/04/11 职场文书
学校百日安全活动总结
2015/05/07 职场文书
2016大学生暑期社会实践心得体会
2016/01/14 职场文书
phpQuery解析HTML乱码问题(补充官网未列出的乱码解决方案)
2021/04/01 PHP
pytorch训练神经网络爆内存的解决方案
2021/05/22 Python
【海涛DOTA解说】EVE女子战队独家录像加ZSMJ神牛两连发
2022/04/01 DOTA
《仙剑客栈2》第一弹正式宣传片公开 年内发售
2022/04/07 其他游戏