Python使用numpy实现BP神经网络


Posted in Python onMarch 10, 2018

本文完全利用numpy实现一个简单的BP神经网络,由于是做regression而不是classification,因此在这里输出层选取的激励函数就是f(x)=x。BP神经网络的具体原理此处不再介绍。

import numpy as np 
 
class NeuralNetwork(object): 
  def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate): 
    # Set number of nodes in input, hidden and output layers.设定输入层、隐藏层和输出层的node数目 
    self.input_nodes = input_nodes 
    self.hidden_nodes = hidden_nodes 
    self.output_nodes = output_nodes 
 
    # Initialize weights,初始化权重和学习速率 
    self.weights_input_to_hidden = np.random.normal(0.0, self.hidden_nodes**-0.5,  
                    ( self.hidden_nodes, self.input_nodes)) 
 
    self.weights_hidden_to_output = np.random.normal(0.0, self.output_nodes**-0.5,  
                    (self.output_nodes, self.hidden_nodes)) 
    self.lr = learning_rate 
     
    # 隐藏层的激励函数为sigmoid函数,Activation function is the sigmoid function 
    self.activation_function = (lambda x: 1/(1 + np.exp(-x))) 
   
  def train(self, inputs_list, targets_list): 
    # Convert inputs list to 2d array 
    inputs = np.array(inputs_list, ndmin=2).T  # 输入向量的shape为 [feature_diemension, 1] 
    targets = np.array(targets_list, ndmin=2).T  
 
    # 向前传播,Forward pass 
    # TODO: Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
 
     
    # 输出层,输出层的激励函数就是 y = x 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer 
     
    ### 反向传播 Backward pass,使用梯度下降对权重进行更新 ### 
     
    # 输出误差 
    # Output layer error is the difference between desired target and actual output. 
    output_errors = (targets_list-final_outputs) 
 
    # 反向传播误差 Backpropagated error 
    # errors propagated to the hidden layer 
    hidden_errors = np.dot(output_errors, self.weights_hidden_to_output)*(hidden_outputs*(1-hidden_outputs)).T 
 
    # 更新权重 Update the weights 
    # 更新隐藏层与输出层之间的权重 update hidden-to-output weights with gradient descent step 
    self.weights_hidden_to_output += output_errors * hidden_outputs.T * self.lr 
    # 更新输入层与隐藏层之间的权重 update input-to-hidden weights with gradient descent step 
    self.weights_input_to_hidden += (inputs * hidden_errors * self.lr).T 
  
  # 进行预测   
  def run(self, inputs_list): 
    # Run a forward pass through the network 
    inputs = np.array(inputs_list, ndmin=2).T 
     
    #### 实现向前传播 Implement the forward pass here #### 
    # 隐藏层 Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
     
    # 输出层 Output layer 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer  
     
    return final_outputs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python根据给定文件返回文件名和扩展名的方法
Mar 27 Python
python实现从网络下载文件并获得文件大小及类型的方法
Apr 28 Python
python实现kNN算法
Dec 20 Python
python脚本实现验证码识别
Jun 07 Python
python 下 CMake 安装配置 OPENCV 4.1.1的方法
Sep 30 Python
CentOS7下安装python3.6.8的教程详解
Jan 03 Python
pytorch 自定义参数不更新方式
Jan 06 Python
python中threading开启关闭线程操作
May 02 Python
PyTorch 中的傅里叶卷积实现示例
Dec 11 Python
Python实现一个论文下载器的过程
Jan 18 Python
python将图片转为矢量图的方法步骤
Mar 30 Python
Python3接口性能测试实例代码
Jun 20 Python
python实现日常记账本小程序
Mar 10 #Python
python实现简单神经网络算法
Mar 10 #Python
TensorFlow saver指定变量的存取
Mar 10 #Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
You might like
PHP聊天室技术
2006/10/09 PHP
PHP安全防范技巧分享
2011/11/03 PHP
php获取本周开始日期和结束日期的方法
2015/03/09 PHP
php文件压缩之PHPZip类用法实例
2015/06/18 PHP
php+js实现点赞功能的示例详解
2020/08/07 PHP
表单项的name命名为submit、reset引起的问题
2007/12/22 Javascript
javascript 面向对象全新理练之数据的封装
2009/12/03 Javascript
在JavaScript并非所有的一切都是对象
2013/04/11 Javascript
url参数中有+、空格、=、%、&、#等特殊符号的问题解决
2013/05/15 Javascript
JavaScript实现弹出子窗口并传值给父窗口
2014/12/18 Javascript
使用Browserify配合jQuery进行编程的超级指南
2015/07/28 Javascript
jQuery position() 函数详解以及jQuery中position函数的应用
2015/12/14 Javascript
JS中frameset框架弹出层实例代码
2016/04/01 Javascript
vue滚动tab跟随切换效果
2020/06/29 Javascript
原生JS实现相邻月份日历
2020/10/13 Javascript
详解Vue的mixin策略
2020/11/19 Vue.js
python实现简单购物商城
2016/05/21 Python
Numpy数据类型转换astype,dtype的方法
2018/06/09 Python
利用python GDAL库读写geotiff格式的遥感影像方法
2018/11/29 Python
解决nohup执行python程序log文件写入不及时的问题
2019/01/14 Python
Python3列表内置方法大全及示例代码小结
2019/05/10 Python
Python手绘可视化工具cutecharts使用实例
2019/12/05 Python
python mock测试的示例
2020/10/19 Python
H5仿微信界面教程(一)
2017/07/05 HTML / CSS
捷克电器和DJ设备网上商店:Electronic-star
2017/07/18 全球购物
美国购买肉、鸭、家禽、鹅肝和熟食网站:D’Artagnan
2018/11/13 全球购物
经济实惠的豪华家具:My-Furniture
2019/03/12 全球购物
公务员综合考察材料
2014/02/01 职场文书
全神贯注教学反思
2014/02/03 职场文书
优秀应届毕业生推荐信
2014/02/18 职场文书
2014-2015学年工作总结
2014/11/27 职场文书
2015年社区教育工作总结
2015/05/13 职场文书
婚宴祝酒词大全
2015/08/10 职场文书
观看《杨善洲》宣传教育片心得体会
2016/01/23 职场文书
2016年学生会感恩节活动总结
2016/04/01 职场文书
详解Redis在SpringBoot工程中的综合应用
2021/10/16 Redis