BP神经网络原理及Python实现代码


Posted in Python onDecember 18, 2018

本文主要讲如何不依赖TenserFlow等高级API实现一个简单的神经网络来做分类,所有的代码都在下面;在构造的数据(通过程序构造)上做了验证,经过1个小时的训练分类的准确率可以达到97%。

完整的结构化代码见于:链接地址

先来说说原理

网络构造

BP神经网络原理及Python实现代码

上面是一个简单的三层网络;输入层包含节点X1 , X2;隐层包含H1,H2;输出层包含O1。
输入节点的数量要等于输入数据的变量数目。
隐层节点的数量通过经验来确定。
如果只是做分类,输出层一般一个节点就够了。

从输入到输出的过程

1.输入节点的输出等于输入,X1节点输入x1时,输出还是x1.
2. 隐层和输出层的输入I为上层输出的加权求和再加偏置,输出为f(I) , f为激活函数,可以取sigmoid。H1的输出为 sigmoid(w1x1 + w2x2 + b)

误差反向传播的过程

Python实现

构造测试数据

# -*- coding: utf-8 -*-
import numpy as np
from random import random as rdn

'''
说明:我们构造1000条数据,每条数据有三个属性(用a1 , a2 , a3表示)
a1 离散型 取值 1 到 10 , 均匀分布
a2 离散型 取值 1 到 10 , 均匀分布
a3 连续型 取值 1 到 100 , 且符合正态分布 
各属性之间独立。

共2个分类(0 , 1),属性值与类别之间的关系如下,
0 : a1 in [1 , 3] and a2 in [4 , 10] and a3 <= 50
1 : a1 in [1 , 3] and a2 in [4 , 10] and a3 > 50
0 : a1 in [1 , 3] and a2 in [1 , 3] and a3 > 30
1 : a1 in [1 , 3] and a2 in [1 , 3] and a3 <= 30
0 : a1 in [4 , 10] and a2 in [4 , 10] and a3 <= 50
1 : a1 in [4 , 10] and a2 in [4 , 10] and a3 > 50
0 : a1 in [4 , 10] and a2 in [1 , 3] and a3 > 30
1 : a1 in [4 , 10] and a2 in [1 , 3] and a3 <= 30
'''


def genData() :
 #为a3生成符合正态分布的数据
 a3_data = np.random.randn(1000) * 30 + 50
 data = []
 for i in range(1000) :
 #生成a1
 a1 = int(rdn()*10) + 1
 if a1 > 10 :
  a1 = 10
 #生成a2
 a2 = int(rdn()*10) + 1
 if a2 > 10 :
  a2 = 10
 #取a3
 a3 = a3_data[i] 
 #计算这条数据对应的类别
 c_id = 0
 if a1 <= 3 and a2 >= 4 and a3 <= 50 :
  c_id = 0 
 elif a1 <= 3 and a2 >= 4 and a3 > 50 :
  c_id = 1 
 elif a1 <= 3 and a2 < 4 and a3 > 30 :
  c_id = 0
 elif a1 <= 3 and a2 < 4 and a3 <= 30 :
  c_id = 1
 elif a1 > 3 and a2 >= 4 and a3 <= 50 :
  c_id = 0 
 elif a1 > 3 and a2 >= 4 and a3 > 50 :
  c_id = 1 
 elif a1 > 3 and a2 < 4 and a3 > 30 :
  c_id = 0
 elif a1 > 3 and a2 < 4 and a3 <= 30 :
  c_id = 1
 else :
  print('error')
 #拼合成字串
 str_line = str(i) + ',' + str(a1) + ',' + str(a2) + ',' + str(a3) + ',' + str(c_id)
 data.append(str_line)
 return '\n'.join(data)

激活函数

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""
import numpy as np

#sigmoid函数的导数为 f(x)*(1-f(x))
def sigmoid(x) :
 return 1/(1 + np.exp(-x))

网络实现

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""

from activation_funcs import sigmoid
from random import random

class InputNode(object) :
 def __init__(self , idx) :
 self.idx = idx
 self.output = None
  
 def setInput(self , value) :
 self.output = value
 
 def getOutput(self) :
 return self.output
 
 def refreshParas(self , p1 , p2) :
 pass
 
 
class Neurode(object) :
 def __init__(self , layer_name , idx , input_nodes , activation_func = None , powers = None , bias = None) :
 self.idx = idx 
 self.layer_name = layer_name
 self.input_nodes = input_nodes 
 if activation_func is not None :
  self.activation_func = activation_func
 else :
  #默认取 sigmoid
  self.activation_func = sigmoid
 if powers is not None :
  self.powers = powers
 else :
  self.powers = [random() for i in range(len(self.input_nodes))]
 if bias is not None :
  self.bias = bias
 else :
  self.bias = random()
 self.output = None
  
 def getOutput(self) :
 self.output = self.activation_func(sum(map(lambda x : x[0].getOutput()*x[1] , zip(self.input_nodes, self.powers))) + self.bias)
 return self.output
  
 def refreshParas(self , err , learn_rate) :
 err_add = self.output * (1 - self.output) * err 
 for i in range(len(self.input_nodes)) :
  #调用子节点
  self.input_nodes[i].refreshParas(self.powers[i] * err_add , learn_rate)
  #调节参数
  power_delta = learn_rate * err_add * self.input_nodes[i].output 
  self.powers[i] += power_delta
  bias_delta = learn_rate * err_add
  self.bias += bias_delta
 
 
class SimpleBP(object) :
 def __init__(self , input_node_num , hidden_layer_node_num , trainning_data , test_data) :
 self.input_node_num = input_node_num
 self.input_nodes = [InputNode(i) for i in range(input_node_num)]
 self.hidden_layer_nodes = [Neurode('H' , i , self.input_nodes) for i in range(hidden_layer_node_num)]
 self.output_node = Neurode('O' , 0 , self.hidden_layer_nodes)
 self.trainning_data = trainning_data
 self.test_data = test_data
 
 
 #逐条训练
 def trainByItem(self) :
 cnt = 0
 while True :
  cnt += 1
  learn_rate = 1.0/cnt
  sum_diff = 0.0
  #对于每一条训练数据进行一次训练过程
  for item in self.trainning_data :
  for i in range(self.input_node_num) :
   self.input_nodes[i].setInput(item[i])
  item_output = item[-1]
  nn_output = self.output_node.getOutput()
  #print('nn_output:' , nn_output)
  diff = (item_output-nn_output)
  sum_diff += abs(diff)
  self.output_node.refreshParas(diff , learn_rate)
  #print('refreshedParas')
  #结束条件 
  print(round(sum_diff / len(self.trainning_data) , 4))
  if sum_diff / len(self.trainning_data) < 0.1 :
  break
 
 def getAccuracy(self) :
 cnt = 0
 for item in self.test_data :
  for i in range(self.input_node_num) :
  self.input_nodes[i].setInput(item[i])
  item_output = item[-1]
  nn_output = self.output_node.getOutput()
  if (nn_output > 0.5 and item_output > 0.5) or (nn_output < 0.5 and item_output < 0.5) :
  cnt += 1
 return cnt/(len(self.test_data) + 0.0)

主调流程

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 2 14:49:31 2018

@author: congpeiqing
"""
import os
from SimpleBP import SimpleBP
from GenData import genData

if not os.path.exists('data'):
 os.makedirs('data') 

#构造训练和测试数据
data_file = open('data/trainning_data.dat' , 'w')
data_file.write(genData())
data_file.close()

data_file = open('data/test_data.dat' , 'w')
data_file.write(genData())
data_file.close()


#文件格式:rec_id,attr1_value,attr2_value,attr3_value,class_id
#读取和解析训练数据
trainning_data_file = open('data/trainning_data.dat')
trainning_data = []
for line in trainning_data_file :
 line = line.strip()
 fld_list = line.split(',')
 trainning_data.append(tuple([float(field) for field in fld_list[1:]]))
trainning_data_file.close()

#读取和解析测试数据
test_data_file = open('data/test_data.dat')
test_data = []
for line in test_data_file :
 line = line.strip()
 fld_list = line.split(',')
 test_data.append(tuple([float(field) for field in fld_list[1:]]))
test_data_file.close()


#构造一个二分类网络 输入节点3个,隐层节点10个,输出节点一个
simple_bp = SimpleBP(3 , 10 , trainning_data , test_data)
#训练网络
simple_bp.trainByItem()
#测试分类准确率
print('Accuracy : ' , simple_bp.getAccuracy())
#训练时长比较长,准确率可以达到97%

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用htpasswd实现基本认证授权的例子
Jun 10 Python
利用Python学习RabbitMQ消息队列
Nov 30 Python
Python引用类型和值类型的区别与使用解析
Oct 17 Python
运行django项目指定IP和端口的方法
May 14 Python
详解python中的hashlib模块的使用
Apr 22 Python
总结Python图形用户界面和游戏开发知识点
May 22 Python
微信小程序python用户认证的实现
Jul 29 Python
python 生成器和迭代器的原理解析
Oct 12 Python
如何使用python传入不确定个数参数
Feb 18 Python
Jupyter加载文件的实现方法
Apr 14 Python
python查看矩阵的行列号以及维数方式
May 22 Python
python实现测试工具(二)——简单的ui测试工具
Oct 19 Python
python 执行文件时额外参数获取的实例
Dec 18 #Python
python实现基于信息增益的决策树归纳
Dec 18 #Python
Django实现一对多表模型的跨表查询方法
Dec 18 #Python
Python实现字典排序、按照list中字典的某个key排序的方法示例
Dec 18 #Python
python实现求特征选择的信息增益
Dec 18 #Python
python实现连续图文识别
Dec 18 #Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 #Python
You might like
PHP关联数组的10个操作技巧
2013/01/21 PHP
ThinkPHP登录功能的实现方法
2014/08/20 PHP
Thinkphp搭建包括JS多语言的多语言项目实现方法
2014/11/24 PHP
YII CLinkPager分页类扩展增加显示共多少页
2016/01/29 PHP
PHP运用foreach神奇的转换数组(实例讲解)
2018/02/01 PHP
PHP7.3.10编译安装教程
2019/10/08 PHP
php的对象传值与引用传值代码实例讲解
2021/02/26 PHP
jQuery EasyUI API 中文文档 - ComboTree组合树
2011/10/11 Javascript
JS 实现获取打开一个界面中输入的值
2013/03/19 Javascript
jquery ajax中使用jsonp的限制解决方法
2013/11/22 Javascript
js中window.open打开一个新的页面
2014/08/10 Javascript
JavaScript实现梯形乘法表的方法
2015/04/25 Javascript
JavaScript高级教程5.6之基本包装类型(详细)
2015/11/23 Javascript
一分钟理解js闭包
2016/05/04 Javascript
浅谈js图片前端预览之filereader和window.URL.createObjectURL
2016/06/30 Javascript
JavaScript制作简单分页插件
2016/09/11 Javascript
通过扫描二维码打开app的实现代码
2016/11/10 Javascript
鼠标点击input,显示瞬间的边框颜色,对之修改与隐藏实例
2016/12/26 Javascript
js编写简单的计时器功能
2017/07/15 Javascript
JS实现点星星消除小游戏
2020/03/24 Javascript
用Python编写简单的微博爬虫
2016/03/04 Python
python中学习K-Means和图片压缩
2017/11/20 Python
python实现简易云音乐播放器
2018/01/04 Python
django缓存配置的几种方法详解
2018/07/16 Python
Python实现深度遍历和广度遍历的方法
2019/01/22 Python
django的分页器Paginator 从django中导入类
2019/07/25 Python
PyTorch中反卷积的用法详解
2019/12/30 Python
Perfume’s Club德国官网:在线购买香水
2019/04/08 全球购物
俄罗斯皮肤健康中心:Pharmacosmetica.ru
2020/02/22 全球购物
师德师风个人反思
2014/04/28 职场文书
高中学校对照检查材料
2014/08/31 职场文书
2015毕业生自我评价范文
2015/03/02 职场文书
2016年小学生教师节广播稿
2015/12/18 职场文书
Vue实现下拉加载更多
2021/05/09 Vue.js
mysql中数据库覆盖导入的几种方式总结
2022/03/25 MySQL
SqlServer常用函数及时间处理小结
2023/05/08 SQL Server