Python实现的NN神经网络算法完整示例


Posted in Python onJune 19, 2018

本文实例讲述了Python实现的NN神经网络算法。分享给大家供大家参考,具体如下:

参考自Github开源代码:https://github.com/dennybritz/nn-from-scratch

运行环境

  • Pyhton3
  • numpy(科学计算包)
  • matplotlib(画图所需,不画图可不必)
  • sklearn(人工智能包,生成数据使用)

计算过程

Python实现的NN神经网络算法完整示例

输入样例

none

代码实现

# -*- coding:utf-8 -*-
#!python3
__author__ = 'Wsine'
import numpy as np
import sklearn
import sklearn.datasets
import sklearn.linear_model
import matplotlib.pyplot as plt
import matplotlib
import operator
import time
def createData(dim=200, cnoise=0.20):
  """
  输出:数据集, 对应的类别标签
  描述:生成一个数据集和对应的类别标签
  """
  np.random.seed(0)
  X, y = sklearn.datasets.make_moons(dim, noise=cnoise)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
  #plt.show()
  return X, y
def plot_decision_boundary(pred_func, X, y):
  """
  输入:边界函数, 数据集, 类别标签
  描述:绘制决策边界(画图用)
  """
  # 设置最小最大值, 加上一点外边界
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # 根据最小最大值和一个网格距离生成整个网格
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # 对整个网格预测边界值
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # 绘制边界和数据集的点
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
def calculate_loss(model, X, y):
  """
  输入:训练模型, 数据集, 类别标签
  输出:误判的概率
  描述:计算整个模型的性能
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播来计算预测的分类值
  z1 = X.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  # 计算误判概率
  corect_logprobs = -np.log(probs[range(num_examples), y])
  data_loss = np.sum(corect_logprobs)
  # 加入正则项修正错误(可选)
  data_loss += reg_lambda/2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))
  return 1./num_examples * data_loss
def predict(model, x):
  """
  输入:训练模型, 预测向量
  输出:判决类别
  描述:预测类别属于(0 or 1)
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播计算
  z1 = x.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  return np.argmax(probs, axis=1)
def initParameter(X):
  """
  输入:数据集
  描述:初始化神经网络算法的参数
     必须初始化为全局函数!
     这里需要手动设置!
  """
  global num_examples
  num_examples = len(X) # 训练集的大小
  global nn_input_dim
  nn_input_dim = 2 # 输入层维数
  global nn_output_dim
  nn_output_dim = 2 # 输出层维数
  # 梯度下降参数
  global epsilon
  epsilon = 0.01 # 梯度下降学习步长
  global reg_lambda
  reg_lambda = 0.01 # 修正的指数
def build_model(X, y, nn_hdim, num_passes=20000, print_loss=False):
  """
  输入:数据集, 类别标签, 隐藏层层数, 迭代次数, 是否输出误判率
  输出:神经网络模型
  描述:生成一个指定层数的神经网络模型
  """
  # 根据维度随机初始化参数
  np.random.seed(0)
  W1 = np.random.randn(nn_input_dim, nn_hdim) / np.sqrt(nn_input_dim)
  b1 = np.zeros((1, nn_hdim))
  W2 = np.random.randn(nn_hdim, nn_output_dim) / np.sqrt(nn_hdim)
  b2 = np.zeros((1, nn_output_dim))
  model = {}
  # 梯度下降
  for i in range(0, num_passes):
    # 正向传播
    z1 = X.dot(W1) + b1
    a1 = np.tanh(z1) # 激活函数使用tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    z2 = a1.dot(W2) + b2
    exp_scores = np.exp(z2) # 原始归一化
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    # 后向传播
    delta3 = probs
    delta3[range(num_examples), y] -= 1
    dW2 = (a1.T).dot(delta3)
    db2 = np.sum(delta3, axis=0, keepdims=True)
    delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))
    dW1 = np.dot(X.T, delta2)
    db1 = np.sum(delta2, axis=0)
    # 加入修正项
    dW2 += reg_lambda * W2
    dW1 += reg_lambda * W1
    # 更新梯度下降参数
    W1 += -epsilon * dW1
    b1 += -epsilon * db1
    W2 += -epsilon * dW2
    b2 += -epsilon * db2
    # 更新模型
    model = { 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
    # 一定迭代次数后输出当前误判率
    if print_loss and i % 1000 == 0:
      print("Loss after iteration %i: %f" % (i, calculate_loss(model, X, y)))
  plot_decision_boundary(lambda x: predict(model, x), X, y)
  plt.title("Decision Boundary for hidden layer size %d" % nn_hdim)
  #plt.show()
  return model
def main():
  dataSet, labels = createData(200, 0.20)
  initParameter(dataSet)
  nnModel = build_model(dataSet, labels, 3, print_loss=False)
  print("Loss is %f" % calculate_loss(nnModel, dataSet, labels))
if __name__ == '__main__':
  start = time.clock()
  main()
  end = time.clock()
  print('finish all in %s' % str(end - start))
  plt.show()

输出样例

Loss is 0.071316
finish all in 7.221354361552228

Python实现的NN神经网络算法完整示例

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python 网络编程起步(Socket发送消息)
Sep 06 Python
python中wx将图标显示在右下角的脚本代码
Mar 08 Python
详解Python发送邮件实例
Jan 10 Python
Python脚本实现自动将数据库备份到 Dropbox
Feb 06 Python
Python实现按中文排序的方法示例
Apr 25 Python
Python列表与元组的异同详解
Jul 02 Python
在Python中使用filter去除列表中值为假及空字符串的例子
Nov 18 Python
python的json中方法及jsonpath模块用法分析
Dec 06 Python
Python3 xml.etree.ElementTree支持的XPath语法详解
Mar 06 Python
Python request中文乱码问题解决方案
Sep 17 Python
基于python+selenium自动健康打卡的实现代码
Jan 13 Python
详解Python中的for循环
Apr 30 Python
python中的二维列表实例详解
Jun 19 #Python
Tensorflow中使用tfrecord方式读取数据的方法
Jun 19 #Python
python3实现SMTP发送邮件详细教程
Jun 19 #Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 #Python
Tensorflow使用tfrecord输入数据格式
Jun 19 #Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
You might like
PHP的FTP学习(二)
2006/10/09 PHP
php快速排序原理与实现方法分析
2016/05/26 PHP
laravel5使用freetds连接sql server的方法
2018/12/07 PHP
JS应用之禁止抓屏、复制、打印
2008/02/21 Javascript
潜说js对象和数组
2011/05/25 Javascript
关于html+ashx开发中几个问题的解决方法
2011/07/18 Javascript
仅IE支持clearAttributes/mergeAttributes方法使用介绍
2012/05/04 Javascript
单元选择合并变色示例代码
2014/05/26 Javascript
显示今天的日期js代码(阳历和农历)
2014/09/30 Javascript
基于jQuery Circlr插件实现产品图片360度旋转
2015/09/20 Javascript
Jquery组件easyUi实现表单验证示例
2016/08/23 Javascript
深入理解JS实现快速排序和去重
2016/10/17 Javascript
JavaScript计算值然后把值嵌入到html中的实现方法
2016/10/29 Javascript
JavaScript 最佳实践:帮你提升代码质量
2016/12/03 Javascript
简单实现jQuery弹幕效果
2017/05/06 jQuery
vue和webpack打包项目相对路径修改的方法
2018/06/15 Javascript
深入浅析JavaScript中的in关键字和for-in循环
2020/04/20 Javascript
Node.js API详解之 readline模块用法详解
2020/05/22 Javascript
vue项目使用$router.go(-1)返回时刷新原来的界面操作
2020/07/26 Javascript
Python Xml文件添加字节属性的方法
2018/03/31 Python
python2.6.6如何升级到python2.7.14
2018/04/08 Python
python使用turtle库与random库绘制雪花
2018/06/22 Python
如何使用Python实现自动化水军评论
2019/06/26 Python
使用Python轻松完成垃圾分类(基于图像识别)
2019/07/09 Python
如何基于Python获取图片的物理尺寸
2019/11/25 Python
Python+appium框架原生代码实现App自动化测试详解
2020/03/06 Python
python 识别登录验证码图片功能的实现代码(完整代码)
2020/07/03 Python
python 多线程死锁问题的解决方案
2020/08/25 Python
欧洲第一中国智能手机和平板电脑网上商店:CECT-SHOP
2018/01/08 全球购物
菲律宾领先的在线时尚商店:Zalora菲律宾
2018/02/08 全球购物
经济学人订阅:The Economist
2018/07/19 全球购物
《盘古开天地》教学反思
2014/02/28 职场文书
就业协议书
2014/09/12 职场文书
村委会贫困证明范文
2014/09/21 职场文书
培养联系人考察意见
2015/06/01 职场文书
React配置子路由的实现
2021/06/03 Javascript