Python使用numpy实现BP神经网络


Posted in Python onMarch 10, 2018

本文完全利用numpy实现一个简单的BP神经网络,由于是做regression而不是classification,因此在这里输出层选取的激励函数就是f(x)=x。BP神经网络的具体原理此处不再介绍。

import numpy as np 
 
class NeuralNetwork(object): 
  def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate): 
    # Set number of nodes in input, hidden and output layers.设定输入层、隐藏层和输出层的node数目 
    self.input_nodes = input_nodes 
    self.hidden_nodes = hidden_nodes 
    self.output_nodes = output_nodes 
 
    # Initialize weights,初始化权重和学习速率 
    self.weights_input_to_hidden = np.random.normal(0.0, self.hidden_nodes**-0.5,  
                    ( self.hidden_nodes, self.input_nodes)) 
 
    self.weights_hidden_to_output = np.random.normal(0.0, self.output_nodes**-0.5,  
                    (self.output_nodes, self.hidden_nodes)) 
    self.lr = learning_rate 
     
    # 隐藏层的激励函数为sigmoid函数,Activation function is the sigmoid function 
    self.activation_function = (lambda x: 1/(1 + np.exp(-x))) 
   
  def train(self, inputs_list, targets_list): 
    # Convert inputs list to 2d array 
    inputs = np.array(inputs_list, ndmin=2).T  # 输入向量的shape为 [feature_diemension, 1] 
    targets = np.array(targets_list, ndmin=2).T  
 
    # 向前传播,Forward pass 
    # TODO: Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
 
     
    # 输出层,输出层的激励函数就是 y = x 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer 
     
    ### 反向传播 Backward pass,使用梯度下降对权重进行更新 ### 
     
    # 输出误差 
    # Output layer error is the difference between desired target and actual output. 
    output_errors = (targets_list-final_outputs) 
 
    # 反向传播误差 Backpropagated error 
    # errors propagated to the hidden layer 
    hidden_errors = np.dot(output_errors, self.weights_hidden_to_output)*(hidden_outputs*(1-hidden_outputs)).T 
 
    # 更新权重 Update the weights 
    # 更新隐藏层与输出层之间的权重 update hidden-to-output weights with gradient descent step 
    self.weights_hidden_to_output += output_errors * hidden_outputs.T * self.lr 
    # 更新输入层与隐藏层之间的权重 update input-to-hidden weights with gradient descent step 
    self.weights_input_to_hidden += (inputs * hidden_errors * self.lr).T 
  
  # 进行预测   
  def run(self, inputs_list): 
    # Run a forward pass through the network 
    inputs = np.array(inputs_list, ndmin=2).T 
     
    #### 实现向前传播 Implement the forward pass here #### 
    # 隐藏层 Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
     
    # 输出层 Output layer 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer  
     
    return final_outputs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python代码制作configure文件示例
Jul 28 Python
Python中使用装饰器时需要注意的一些问题
May 11 Python
Python 专题一 函数的基础知识
Mar 16 Python
python实现分页效果
Oct 25 Python
纯python实现机器学习之kNN算法示例
Mar 01 Python
Python基于辗转相除法求解最大公约数的方法示例
Apr 04 Python
在django中图片上传的格式校验及大小方法
Jul 28 Python
Python常用模块os.path之文件及路径操作方法
Dec 03 Python
浅析Python数字类型和字符串类型的内置方法
Dec 22 Python
Django自关联实现多级联动查询实例
May 19 Python
解决Windows下python和pip命令无法使用的问题
Aug 31 Python
解决TensorFlow训练模型及保存数量限制的问题
Mar 03 Python
python实现日常记账本小程序
Mar 10 #Python
python实现简单神经网络算法
Mar 10 #Python
TensorFlow saver指定变量的存取
Mar 10 #Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
You might like
星际流派综述
2020/03/04 星际争霸
php数组函数序列之array_splice() - 在数组任意位置插入元素
2011/11/07 PHP
PHP 类相关函数的使用详解
2013/05/10 PHP
利用php获取服务器时间的实现代码
2013/06/07 PHP
thinkphp连贯操作实例分析
2014/11/22 PHP
php去除字符串中空字符的常用方法小结
2015/03/17 PHP
PHP检测用户是否关闭浏览器的方法
2016/02/14 PHP
php中的登陆login实例代码
2016/06/20 PHP
父窗口获取弹出子窗口文本框的值
2006/06/27 Javascript
jQuery实现可用于博客的动态滑动菜单
2015/03/09 Javascript
jQuery检测输入的字符串包含的中英文的数量
2015/04/17 Javascript
jQuery 实现评论等级好评差评特效
2016/05/06 Javascript
全面解析Javascript无限添加QQ好友原理
2016/06/15 Javascript
Ubuntu 16.04 64位中搭建Node.js开发环境教程
2016/10/19 Javascript
PHP实现本地图片上传和验证功能
2017/02/27 Javascript
浅谈nodejs中的类定义和继承的套路
2017/07/26 NodeJs
Vue项目实现换肤功能的一种方案分析
2019/08/28 Javascript
微信小程序点击view动态添加样式过程解析
2020/01/21 Javascript
js实现电灯开关效果
2021/01/19 Javascript
python使用正则表达式匹配字符串开头并打印示例
2017/01/11 Python
Python之日期与时间处理模块(date和datetime)
2017/02/16 Python
Python探索之修改Python搜索路径
2017/10/25 Python
Python基于mysql实现学生管理系统
2019/02/21 Python
Python实现的栈、队列、文件目录遍历操作示例
2019/05/06 Python
Python直接赋值、浅拷贝与深度拷贝实例分析
2019/06/18 Python
Django如何简单快速实现PUT、DELETE方法
2019/07/24 Python
使用wxpy实现自动发送微信消息功能
2020/02/28 Python
Python RabbitMQ实现简单的进程间通信示例
2020/07/02 Python
手把手教你用纯css3实现轮播图效果实例
2017/05/04 HTML / CSS
Giglio俄罗斯奢侈品购物网:男士、女士、儿童高级时装
2018/07/27 全球购物
新电JAVA笔试题目
2014/08/31 面试题
学校后勤岗位职责
2014/02/19 职场文书
维修工先进事迹
2014/05/29 职场文书
民主评议党员自我评价材料
2014/09/18 职场文书
中小企业员工手册范本
2015/05/14 职场文书
在Windows Server 2012上安装 .NET Framework 3.5 所遇到的问题
2022/04/29 Servers