Python实现的NN神经网络算法完整示例


Posted in Python onJune 19, 2018

本文实例讲述了Python实现的NN神经网络算法。分享给大家供大家参考,具体如下:

参考自Github开源代码:https://github.com/dennybritz/nn-from-scratch

运行环境

  • Pyhton3
  • numpy(科学计算包)
  • matplotlib(画图所需,不画图可不必)
  • sklearn(人工智能包,生成数据使用)

计算过程

Python实现的NN神经网络算法完整示例

输入样例

none

代码实现

# -*- coding:utf-8 -*-
#!python3
__author__ = 'Wsine'
import numpy as np
import sklearn
import sklearn.datasets
import sklearn.linear_model
import matplotlib.pyplot as plt
import matplotlib
import operator
import time
def createData(dim=200, cnoise=0.20):
  """
  输出:数据集, 对应的类别标签
  描述:生成一个数据集和对应的类别标签
  """
  np.random.seed(0)
  X, y = sklearn.datasets.make_moons(dim, noise=cnoise)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
  #plt.show()
  return X, y
def plot_decision_boundary(pred_func, X, y):
  """
  输入:边界函数, 数据集, 类别标签
  描述:绘制决策边界(画图用)
  """
  # 设置最小最大值, 加上一点外边界
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # 根据最小最大值和一个网格距离生成整个网格
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # 对整个网格预测边界值
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # 绘制边界和数据集的点
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
def calculate_loss(model, X, y):
  """
  输入:训练模型, 数据集, 类别标签
  输出:误判的概率
  描述:计算整个模型的性能
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播来计算预测的分类值
  z1 = X.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  # 计算误判概率
  corect_logprobs = -np.log(probs[range(num_examples), y])
  data_loss = np.sum(corect_logprobs)
  # 加入正则项修正错误(可选)
  data_loss += reg_lambda/2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))
  return 1./num_examples * data_loss
def predict(model, x):
  """
  输入:训练模型, 预测向量
  输出:判决类别
  描述:预测类别属于(0 or 1)
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播计算
  z1 = x.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  return np.argmax(probs, axis=1)
def initParameter(X):
  """
  输入:数据集
  描述:初始化神经网络算法的参数
     必须初始化为全局函数!
     这里需要手动设置!
  """
  global num_examples
  num_examples = len(X) # 训练集的大小
  global nn_input_dim
  nn_input_dim = 2 # 输入层维数
  global nn_output_dim
  nn_output_dim = 2 # 输出层维数
  # 梯度下降参数
  global epsilon
  epsilon = 0.01 # 梯度下降学习步长
  global reg_lambda
  reg_lambda = 0.01 # 修正的指数
def build_model(X, y, nn_hdim, num_passes=20000, print_loss=False):
  """
  输入:数据集, 类别标签, 隐藏层层数, 迭代次数, 是否输出误判率
  输出:神经网络模型
  描述:生成一个指定层数的神经网络模型
  """
  # 根据维度随机初始化参数
  np.random.seed(0)
  W1 = np.random.randn(nn_input_dim, nn_hdim) / np.sqrt(nn_input_dim)
  b1 = np.zeros((1, nn_hdim))
  W2 = np.random.randn(nn_hdim, nn_output_dim) / np.sqrt(nn_hdim)
  b2 = np.zeros((1, nn_output_dim))
  model = {}
  # 梯度下降
  for i in range(0, num_passes):
    # 正向传播
    z1 = X.dot(W1) + b1
    a1 = np.tanh(z1) # 激活函数使用tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    z2 = a1.dot(W2) + b2
    exp_scores = np.exp(z2) # 原始归一化
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    # 后向传播
    delta3 = probs
    delta3[range(num_examples), y] -= 1
    dW2 = (a1.T).dot(delta3)
    db2 = np.sum(delta3, axis=0, keepdims=True)
    delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))
    dW1 = np.dot(X.T, delta2)
    db1 = np.sum(delta2, axis=0)
    # 加入修正项
    dW2 += reg_lambda * W2
    dW1 += reg_lambda * W1
    # 更新梯度下降参数
    W1 += -epsilon * dW1
    b1 += -epsilon * db1
    W2 += -epsilon * dW2
    b2 += -epsilon * db2
    # 更新模型
    model = { 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
    # 一定迭代次数后输出当前误判率
    if print_loss and i % 1000 == 0:
      print("Loss after iteration %i: %f" % (i, calculate_loss(model, X, y)))
  plot_decision_boundary(lambda x: predict(model, x), X, y)
  plt.title("Decision Boundary for hidden layer size %d" % nn_hdim)
  #plt.show()
  return model
def main():
  dataSet, labels = createData(200, 0.20)
  initParameter(dataSet)
  nnModel = build_model(dataSet, labels, 3, print_loss=False)
  print("Loss is %f" % calculate_loss(nnModel, dataSet, labels))
if __name__ == '__main__':
  start = time.clock()
  main()
  end = time.clock()
  print('finish all in %s' % str(end - start))
  plt.show()

输出样例

Loss is 0.071316
finish all in 7.221354361552228

Python实现的NN神经网络算法完整示例

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
动态规划之矩阵连乘问题Python实现方法
Nov 27 Python
influx+grafana自定义python采集数据和一些坑的总结
Sep 17 Python
用Python逐行分析文件方法
Jan 28 Python
python实现连连看辅助之图像识别延伸
Jul 17 Python
python Matplotlib底图中鼠标滑过显示隐藏内容的实例代码
Jul 31 Python
PyTorch中常用的激活函数的方法示例
Aug 20 Python
Python socket 套接字实现通信详解
Aug 27 Python
使用PyTorch将文件夹下的图片分为训练集和验证集实例
Jan 08 Python
解析pip安装第三方库但PyCharm中却无法识别的问题及PyCharm安装第三方库的方法教程
Mar 10 Python
基于python tkinter的点名小程序功能的实例代码
Aug 22 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
May 28 Python
Python使用Beautiful Soup(BS4)库解析HTML和XML
Jun 05 Python
python中的二维列表实例详解
Jun 19 #Python
Tensorflow中使用tfrecord方式读取数据的方法
Jun 19 #Python
python3实现SMTP发送邮件详细教程
Jun 19 #Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 #Python
Tensorflow使用tfrecord输入数据格式
Jun 19 #Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
You might like
一个php作的文本留言本的例子(六)
2006/10/09 PHP
基于python发送邮件的乱码问题的解决办法
2013/04/25 PHP
php通过pecl方式安装扩展的实例讲解
2018/02/02 PHP
javascript实现的鼠标链接提示效果生成器代码
2007/06/28 Javascript
获取当前网页document.url location.href区别总结
2008/05/10 Javascript
过虑特殊字符输入的js代码
2010/08/05 Javascript
最短的javascript:地址栏载入脚本代码
2011/10/13 Javascript
完美解决IE低版本不支持call与apply的问题
2013/12/05 Javascript
一个不错的js html页面倒计时可精确到秒
2014/10/22 Javascript
基于Bootstrap使用jQuery实现输入框组input-group的添加与删除
2016/05/03 Javascript
JavaScript中的跨浏览器事件操作的基本方法整理
2016/05/20 Javascript
JQuery在循环中绑定事件的问题详解
2016/06/02 Javascript
教你如何在Node.js中使用jQuery
2016/08/28 Javascript
Angularjs的Controller间通信机制实例分析
2016/11/07 Javascript
Bootstrap3 图片(响应式图片&图片形状)
2017/01/04 Javascript
基于daterangepicker日历插件使用参数注意的问题
2017/08/10 Javascript
JS从非数组对象转数组的方法小结
2018/03/26 Javascript
解决axios发送post请求返回400状态码的问题
2018/08/11 Javascript
微信小程序实现的canvas合成图片功能示例
2019/05/03 Javascript
小程序根据手机机型设置自定义底部导航距离
2019/06/04 Javascript
浅谈一种让小程序支持JSX语法的新思路
2019/06/16 Javascript
通过原生vue添加滚动加载更多功能
2019/11/21 Javascript
[48:24]完美世界DOTA2联赛循环赛LBZS vs Forest 第一场 10月30日
2020/10/31 DOTA
Python 元类使用说明
2009/12/18 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
2017/09/05 Python
Python使用try except处理程序异常的三种常用方法分析
2018/09/05 Python
Python脚本完成post接口测试的实例
2018/12/17 Python
python3使用flask编写注册post接口的方法
2018/12/28 Python
python 将有序数组转换为二叉树的方法
2019/03/26 Python
Python参数传递实现过程及原理详解
2020/05/14 Python
应届生高等护理求职信
2013/10/12 职场文书
新闻编辑专业毕业自荐书范文
2014/02/05 职场文书
干部个人对照检查材料
2014/08/25 职场文书
生产设备维护保养制度
2015/08/06 职场文书
乡镇团代会开幕词
2016/03/04 职场文书
Python中的turtle画箭头,矩形,五角星
2022/03/16 Python