Python使用numpy实现BP神经网络


Posted in Python onMarch 10, 2018

本文完全利用numpy实现一个简单的BP神经网络,由于是做regression而不是classification,因此在这里输出层选取的激励函数就是f(x)=x。BP神经网络的具体原理此处不再介绍。

import numpy as np 
 
class NeuralNetwork(object): 
  def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate): 
    # Set number of nodes in input, hidden and output layers.设定输入层、隐藏层和输出层的node数目 
    self.input_nodes = input_nodes 
    self.hidden_nodes = hidden_nodes 
    self.output_nodes = output_nodes 
 
    # Initialize weights,初始化权重和学习速率 
    self.weights_input_to_hidden = np.random.normal(0.0, self.hidden_nodes**-0.5,  
                    ( self.hidden_nodes, self.input_nodes)) 
 
    self.weights_hidden_to_output = np.random.normal(0.0, self.output_nodes**-0.5,  
                    (self.output_nodes, self.hidden_nodes)) 
    self.lr = learning_rate 
     
    # 隐藏层的激励函数为sigmoid函数,Activation function is the sigmoid function 
    self.activation_function = (lambda x: 1/(1 + np.exp(-x))) 
   
  def train(self, inputs_list, targets_list): 
    # Convert inputs list to 2d array 
    inputs = np.array(inputs_list, ndmin=2).T  # 输入向量的shape为 [feature_diemension, 1] 
    targets = np.array(targets_list, ndmin=2).T  
 
    # 向前传播,Forward pass 
    # TODO: Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
 
     
    # 输出层,输出层的激励函数就是 y = x 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer 
     
    ### 反向传播 Backward pass,使用梯度下降对权重进行更新 ### 
     
    # 输出误差 
    # Output layer error is the difference between desired target and actual output. 
    output_errors = (targets_list-final_outputs) 
 
    # 反向传播误差 Backpropagated error 
    # errors propagated to the hidden layer 
    hidden_errors = np.dot(output_errors, self.weights_hidden_to_output)*(hidden_outputs*(1-hidden_outputs)).T 
 
    # 更新权重 Update the weights 
    # 更新隐藏层与输出层之间的权重 update hidden-to-output weights with gradient descent step 
    self.weights_hidden_to_output += output_errors * hidden_outputs.T * self.lr 
    # 更新输入层与隐藏层之间的权重 update input-to-hidden weights with gradient descent step 
    self.weights_input_to_hidden += (inputs * hidden_errors * self.lr).T 
  
  # 进行预测   
  def run(self, inputs_list): 
    # Run a forward pass through the network 
    inputs = np.array(inputs_list, ndmin=2).T 
     
    #### 实现向前传播 Implement the forward pass here #### 
    # 隐藏层 Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
     
    # 输出层 Output layer 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer  
     
    return final_outputs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python装饰器decorator介绍
Nov 21 Python
分享几道你可能遇到的python面试题
Jul 24 Python
Python操作csv文件实例详解
Jul 31 Python
python删除字符串中指定字符的方法
Aug 13 Python
python 实现批量xls文件转csv文件的方法
Oct 23 Python
pandas dataframe的合并实现(append, merge, concat)
Jun 24 Python
Python之time模块的时间戳,时间字符串格式化与转换方法(13位时间戳)
Aug 12 Python
解决python replace函数替换无效问题
Jan 18 Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 Python
TensorFlow实现checkpoint文件转换为pb文件
Feb 10 Python
Python 线性回归分析以及评价指标详解
Apr 02 Python
Pycharm 如何一键加引号的方法步骤
Feb 05 Python
python实现日常记账本小程序
Mar 10 #Python
python实现简单神经网络算法
Mar 10 #Python
TensorFlow saver指定变量的存取
Mar 10 #Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
You might like
php数组函数序列之array_flip() 将数组键名与值对调
2011/11/07 PHP
解析PHP SPL标准库的用法(遍历目录,查找固定条件的文件)
2013/06/18 PHP
Yii2 hasOne(), hasMany() 实现三表关联的方法(两种)
2017/02/15 PHP
php删除二维数组中的重复值方法
2018/03/12 PHP
ASP.NET jQuery 实例3 (在TextBox里面阻止复制、剪切和粘贴事件)
2012/01/13 Javascript
jQuery+PHP打造滑动开关效果
2014/12/16 Javascript
jQuery插件开发的五种形态小结
2015/03/04 Javascript
nodejs实现HTTPS发起POST请求
2015/04/23 NodeJs
Extjs4.0 ComboBox如何实现三级联动
2016/05/11 Javascript
简单实现js点击展开二级菜单功能
2017/05/16 Javascript
vue-resouce设置请求头的三种方法
2017/09/12 Javascript
vue-router判断页面未登录自动跳转到登录页的方法示例
2018/11/04 Javascript
js删除数组中某几项的方法总结
2019/01/16 Javascript
javascript中数组的常用算法深入分析
2019/03/12 Javascript
JavaScript展开操作符(Spread operator)详解
2019/07/20 Javascript
Vue实现简单计算器案例
2020/02/25 Javascript
[02:38]2018DOTA2亚洲邀请赛赛前采访-VGJ.T
2018/04/03 DOTA
Python爬虫抓取手机APP的传输数据
2016/01/22 Python
利用python画一颗心的方法示例
2017/01/31 Python
浅谈Django REST Framework限速
2017/12/12 Python
Python生成短uuid的方法实例详解
2018/05/29 Python
pycharm下查看python的变量类型和变量内容的方法
2018/06/26 Python
咖啡为什么会有酸味?你喝到的咖啡為什麼是酸的?
2021/03/17 冲泡冲煮
解决margin 外边距合并问题
2019/07/03 HTML / CSS
思想政治教育专业个人求职信范文
2013/12/20 职场文书
新店开张活动方案
2014/08/24 职场文书
委托书的写法
2014/08/30 职场文书
群众路线教育实践活动学习心得体会
2014/10/30 职场文书
员工辞职信范文
2015/03/02 职场文书
小升初自荐信怎么写
2015/03/26 职场文书
公司员工奖惩制度
2015/08/04 职场文书
创业计划书之水果店
2019/07/18 职场文书
职业规划从高考志愿专业选择开始
2019/08/08 职场文书
Nginx服务器添加Systemd自定义服务过程解析
2021/03/31 Servers
《帝国时代4》赛季预告 新增内容编译器可创造地图
2022/04/03 其他游戏