BP神经网络原理及Python实现代码


Posted in Python onDecember 18, 2018

本文主要讲如何不依赖TenserFlow等高级API实现一个简单的神经网络来做分类,所有的代码都在下面;在构造的数据(通过程序构造)上做了验证,经过1个小时的训练分类的准确率可以达到97%。

完整的结构化代码见于:链接地址

先来说说原理

网络构造

BP神经网络原理及Python实现代码

上面是一个简单的三层网络;输入层包含节点X1 , X2;隐层包含H1,H2;输出层包含O1。
输入节点的数量要等于输入数据的变量数目。
隐层节点的数量通过经验来确定。
如果只是做分类,输出层一般一个节点就够了。

从输入到输出的过程

1.输入节点的输出等于输入,X1节点输入x1时,输出还是x1.
2. 隐层和输出层的输入I为上层输出的加权求和再加偏置,输出为f(I) , f为激活函数,可以取sigmoid。H1的输出为 sigmoid(w1x1 + w2x2 + b)

误差反向传播的过程

Python实现

构造测试数据

# -*- coding: utf-8 -*-
import numpy as np
from random import random as rdn

'''
说明:我们构造1000条数据,每条数据有三个属性(用a1 , a2 , a3表示)
a1 离散型 取值 1 到 10 , 均匀分布
a2 离散型 取值 1 到 10 , 均匀分布
a3 连续型 取值 1 到 100 , 且符合正态分布 
各属性之间独立。

共2个分类(0 , 1),属性值与类别之间的关系如下,
0 : a1 in [1 , 3] and a2 in [4 , 10] and a3 <= 50
1 : a1 in [1 , 3] and a2 in [4 , 10] and a3 > 50
0 : a1 in [1 , 3] and a2 in [1 , 3] and a3 > 30
1 : a1 in [1 , 3] and a2 in [1 , 3] and a3 <= 30
0 : a1 in [4 , 10] and a2 in [4 , 10] and a3 <= 50
1 : a1 in [4 , 10] and a2 in [4 , 10] and a3 > 50
0 : a1 in [4 , 10] and a2 in [1 , 3] and a3 > 30
1 : a1 in [4 , 10] and a2 in [1 , 3] and a3 <= 30
'''


def genData() :
 #为a3生成符合正态分布的数据
 a3_data = np.random.randn(1000) * 30 + 50
 data = []
 for i in range(1000) :
 #生成a1
 a1 = int(rdn()*10) + 1
 if a1 > 10 :
  a1 = 10
 #生成a2
 a2 = int(rdn()*10) + 1
 if a2 > 10 :
  a2 = 10
 #取a3
 a3 = a3_data[i] 
 #计算这条数据对应的类别
 c_id = 0
 if a1 <= 3 and a2 >= 4 and a3 <= 50 :
  c_id = 0 
 elif a1 <= 3 and a2 >= 4 and a3 > 50 :
  c_id = 1 
 elif a1 <= 3 and a2 < 4 and a3 > 30 :
  c_id = 0
 elif a1 <= 3 and a2 < 4 and a3 <= 30 :
  c_id = 1
 elif a1 > 3 and a2 >= 4 and a3 <= 50 :
  c_id = 0 
 elif a1 > 3 and a2 >= 4 and a3 > 50 :
  c_id = 1 
 elif a1 > 3 and a2 < 4 and a3 > 30 :
  c_id = 0
 elif a1 > 3 and a2 < 4 and a3 <= 30 :
  c_id = 1
 else :
  print('error')
 #拼合成字串
 str_line = str(i) + ',' + str(a1) + ',' + str(a2) + ',' + str(a3) + ',' + str(c_id)
 data.append(str_line)
 return '\n'.join(data)

激活函数

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""
import numpy as np

#sigmoid函数的导数为 f(x)*(1-f(x))
def sigmoid(x) :
 return 1/(1 + np.exp(-x))

网络实现

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""

from activation_funcs import sigmoid
from random import random

class InputNode(object) :
 def __init__(self , idx) :
 self.idx = idx
 self.output = None
  
 def setInput(self , value) :
 self.output = value
 
 def getOutput(self) :
 return self.output
 
 def refreshParas(self , p1 , p2) :
 pass
 
 
class Neurode(object) :
 def __init__(self , layer_name , idx , input_nodes , activation_func = None , powers = None , bias = None) :
 self.idx = idx 
 self.layer_name = layer_name
 self.input_nodes = input_nodes 
 if activation_func is not None :
  self.activation_func = activation_func
 else :
  #默认取 sigmoid
  self.activation_func = sigmoid
 if powers is not None :
  self.powers = powers
 else :
  self.powers = [random() for i in range(len(self.input_nodes))]
 if bias is not None :
  self.bias = bias
 else :
  self.bias = random()
 self.output = None
  
 def getOutput(self) :
 self.output = self.activation_func(sum(map(lambda x : x[0].getOutput()*x[1] , zip(self.input_nodes, self.powers))) + self.bias)
 return self.output
  
 def refreshParas(self , err , learn_rate) :
 err_add = self.output * (1 - self.output) * err 
 for i in range(len(self.input_nodes)) :
  #调用子节点
  self.input_nodes[i].refreshParas(self.powers[i] * err_add , learn_rate)
  #调节参数
  power_delta = learn_rate * err_add * self.input_nodes[i].output 
  self.powers[i] += power_delta
  bias_delta = learn_rate * err_add
  self.bias += bias_delta
 
 
class SimpleBP(object) :
 def __init__(self , input_node_num , hidden_layer_node_num , trainning_data , test_data) :
 self.input_node_num = input_node_num
 self.input_nodes = [InputNode(i) for i in range(input_node_num)]
 self.hidden_layer_nodes = [Neurode('H' , i , self.input_nodes) for i in range(hidden_layer_node_num)]
 self.output_node = Neurode('O' , 0 , self.hidden_layer_nodes)
 self.trainning_data = trainning_data
 self.test_data = test_data
 
 
 #逐条训练
 def trainByItem(self) :
 cnt = 0
 while True :
  cnt += 1
  learn_rate = 1.0/cnt
  sum_diff = 0.0
  #对于每一条训练数据进行一次训练过程
  for item in self.trainning_data :
  for i in range(self.input_node_num) :
   self.input_nodes[i].setInput(item[i])
  item_output = item[-1]
  nn_output = self.output_node.getOutput()
  #print('nn_output:' , nn_output)
  diff = (item_output-nn_output)
  sum_diff += abs(diff)
  self.output_node.refreshParas(diff , learn_rate)
  #print('refreshedParas')
  #结束条件 
  print(round(sum_diff / len(self.trainning_data) , 4))
  if sum_diff / len(self.trainning_data) < 0.1 :
  break
 
 def getAccuracy(self) :
 cnt = 0
 for item in self.test_data :
  for i in range(self.input_node_num) :
  self.input_nodes[i].setInput(item[i])
  item_output = item[-1]
  nn_output = self.output_node.getOutput()
  if (nn_output > 0.5 and item_output > 0.5) or (nn_output < 0.5 and item_output < 0.5) :
  cnt += 1
 return cnt/(len(self.test_data) + 0.0)

主调流程

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""
import os
from SimpleBP import SimpleBP
from GenData import genData

if not os.path.exists('data'):
 os.makedirs('data') 

#构造训练和测试数据
data_file = open('data/trainning_data.dat' , 'w')
data_file.write(genData())
data_file.close()

data_file = open('data/test_data.dat' , 'w')
data_file.write(genData())
data_file.close()


#文件格式:rec_id,attr1_value,attr2_value,attr3_value,class_id
#读取和解析训练数据
trainning_data_file = open('data/trainning_data.dat')
trainning_data = []
for line in trainning_data_file :
 line = line.strip()
 fld_list = line.split(',')
 trainning_data.append(tuple([float(field) for field in fld_list[1:]]))
trainning_data_file.close()

#读取和解析测试数据
test_data_file = open('data/test_data.dat')
test_data = []
for line in test_data_file :
 line = line.strip()
 fld_list = line.split(',')
 test_data.append(tuple([float(field) for field in fld_list[1:]]))
test_data_file.close()


#构造一个二分类网络 输入节点3个,隐层节点10个,输出节点一个
simple_bp = SimpleBP(3 , 10 , trainning_data , test_data)
#训练网络
simple_bp.trainByItem()
#测试分类准确率
print('Accuracy : ' , simple_bp.getAccuracy())
#训练时长比较长,准确率可以达到97%

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过邮件服务器端口发送邮件的方法
Apr 30 Python
python使用PyGame模块播放声音的方法
May 20 Python
Python中使用插入排序算法的简单分析与代码示例
May 04 Python
python 二分查找和快速排序实例详解
Oct 13 Python
Python编程pygame模块实现移动的小车示例代码
Jan 03 Python
Python使用sort和class实现的多级排序功能示例
Aug 15 Python
Python2和Python3中urllib库中urlencode的使用注意事项
Nov 26 Python
Python类中的魔法方法之 __slots__原理解析
Aug 26 Python
python enumerate内置函数用法总结
Jan 07 Python
使用Python内置模块与函数进行不同进制的数的转换
Apr 26 Python
M1芯片安装python3.9.1的实现
Feb 02 Python
告别网页搜索!教你用python实现一款属于自己的翻译词典软件
Jun 03 Python
python 执行文件时额外参数获取的实例
Dec 18 #Python
python实现基于信息增益的决策树归纳
Dec 18 #Python
Django实现一对多表模型的跨表查询方法
Dec 18 #Python
Python实现字典排序、按照list中字典的某个key排序的方法示例
Dec 18 #Python
python实现求特征选择的信息增益
Dec 18 #Python
python实现连续图文识别
Dec 18 #Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 #Python
You might like
BBS(php &amp; mysql)完整版(八)
2006/10/09 PHP
CodeIgniter 完美解决URL含有中文字符串
2016/05/13 PHP
php基于PDO连接MSSQL示例DEMO
2016/07/13 PHP
Javascript实现CheckBox的全选与取消全选的代码
2010/07/20 Javascript
jQuery1.3.2 升级到jQuery1.4.4需要修改的地方
2011/01/06 Javascript
node.js中的fs.truncate方法使用说明
2014/12/15 Javascript
jQuery随机密码生成的方法
2015/03/09 Javascript
Javascript中arguments和arguments.callee的区别浅析
2015/04/24 Javascript
JS实现可调整倒计时间代码分享
2015/08/18 Javascript
微信小程序去哪里找 小程序到底如何使用(附小程序名单)
2017/01/09 Javascript
解决URL地址中的中文乱码问题的办法
2017/02/10 Javascript
详解node+express+ejs+bootstrap构建项目
2017/09/27 Javascript
vue使用xe-utils函数库的具体方法
2018/03/06 Javascript
详谈js的变量提升以及使用方法
2018/10/06 Javascript
微信小程序使用canvas的画图操作示例
2019/01/18 Javascript
jquery实现动态创建form并提交的方法示例
2019/05/27 jQuery
JavaScript indexOf()原理及使用方法详解
2020/07/09 Javascript
[01:36:57]【09DOTA2第一视角】小骷髅
2014/04/16 DOTA
windows系统中python使用rar命令压缩多个文件夹示例
2014/05/06 Python
python实现list元素按关键字相加减的方法示例
2017/06/09 Python
python书籍信息爬虫实例
2018/03/19 Python
使用Python编写Prometheus监控的方法
2018/10/15 Python
Python3并发写文件与Python对比
2019/11/20 Python
python扫描线填充算法详解
2020/02/19 Python
澳洲女装时尚在线:Blue Bungalow
2018/05/05 全球购物
BLACKMORES澳洲官网:澳大利亚排名第一的保健品牌
2018/09/27 全球购物
Currentbody美国/加拿大:美容仪专家
2020/03/09 全球购物
师范生自荐信
2013/10/27 职场文书
公务员职务工作的自我评价
2013/11/01 职场文书
大一军训感言
2014/01/09 职场文书
九年级政治教学反思
2014/02/06 职场文书
数学与统计学院学生个人职业生涯规划书
2014/02/10 职场文书
趣味比赛活动方案
2014/02/15 职场文书
Go语言切片前或中间插入项与内置copy()函数详解
2021/04/27 Golang
python如何利用traceback获取详细的异常信息
2021/06/05 Python
JS精髓原型链继承及构造函数继承问题纠正
2022/06/16 Javascript