Python实现的NN神经网络算法完整示例


Posted in Python onJune 19, 2018

本文实例讲述了Python实现的NN神经网络算法。分享给大家供大家参考,具体如下:

参考自Github开源代码:https://github.com/dennybritz/nn-from-scratch

运行环境

  • Pyhton3
  • numpy(科学计算包)
  • matplotlib(画图所需,不画图可不必)
  • sklearn(人工智能包,生成数据使用)

计算过程

Python实现的NN神经网络算法完整示例

输入样例

none

代码实现

# -*- coding:utf-8 -*-
#!python3
__author__ = 'Wsine'
import numpy as np
import sklearn
import sklearn.datasets
import sklearn.linear_model
import matplotlib.pyplot as plt
import matplotlib
import operator
import time
def createData(dim=200, cnoise=0.20):
  """
  输出:数据集, 对应的类别标签
  描述:生成一个数据集和对应的类别标签
  """
  np.random.seed(0)
  X, y = sklearn.datasets.make_moons(dim, noise=cnoise)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
  #plt.show()
  return X, y
def plot_decision_boundary(pred_func, X, y):
  """
  输入:边界函数, 数据集, 类别标签
  描述:绘制决策边界(画图用)
  """
  # 设置最小最大值, 加上一点外边界
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # 根据最小最大值和一个网格距离生成整个网格
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # 对整个网格预测边界值
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # 绘制边界和数据集的点
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
def calculate_loss(model, X, y):
  """
  输入:训练模型, 数据集, 类别标签
  输出:误判的概率
  描述:计算整个模型的性能
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播来计算预测的分类值
  z1 = X.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  # 计算误判概率
  corect_logprobs = -np.log(probs[range(num_examples), y])
  data_loss = np.sum(corect_logprobs)
  # 加入正则项修正错误(可选)
  data_loss += reg_lambda/2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))
  return 1./num_examples * data_loss
def predict(model, x):
  """
  输入:训练模型, 预测向量
  输出:判决类别
  描述:预测类别属于(0 or 1)
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播计算
  z1 = x.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  return np.argmax(probs, axis=1)
def initParameter(X):
  """
  输入:数据集
  描述:初始化神经网络算法的参数
     必须初始化为全局函数!
     这里需要手动设置!
  """
  global num_examples
  num_examples = len(X) # 训练集的大小
  global nn_input_dim
  nn_input_dim = 2 # 输入层维数
  global nn_output_dim
  nn_output_dim = 2 # 输出层维数
  # 梯度下降参数
  global epsilon
  epsilon = 0.01 # 梯度下降学习步长
  global reg_lambda
  reg_lambda = 0.01 # 修正的指数
def build_model(X, y, nn_hdim, num_passes=20000, print_loss=False):
  """
  输入:数据集, 类别标签, 隐藏层层数, 迭代次数, 是否输出误判率
  输出:神经网络模型
  描述:生成一个指定层数的神经网络模型
  """
  # 根据维度随机初始化参数
  np.random.seed(0)
  W1 = np.random.randn(nn_input_dim, nn_hdim) / np.sqrt(nn_input_dim)
  b1 = np.zeros((1, nn_hdim))
  W2 = np.random.randn(nn_hdim, nn_output_dim) / np.sqrt(nn_hdim)
  b2 = np.zeros((1, nn_output_dim))
  model = {}
  # 梯度下降
  for i in range(0, num_passes):
    # 正向传播
    z1 = X.dot(W1) + b1
    a1 = np.tanh(z1) # 激活函数使用tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    z2 = a1.dot(W2) + b2
    exp_scores = np.exp(z2) # 原始归一化
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    # 后向传播
    delta3 = probs
    delta3[range(num_examples), y] -= 1
    dW2 = (a1.T).dot(delta3)
    db2 = np.sum(delta3, axis=0, keepdims=True)
    delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))
    dW1 = np.dot(X.T, delta2)
    db1 = np.sum(delta2, axis=0)
    # 加入修正项
    dW2 += reg_lambda * W2
    dW1 += reg_lambda * W1
    # 更新梯度下降参数
    W1 += -epsilon * dW1
    b1 += -epsilon * db1
    W2 += -epsilon * dW2
    b2 += -epsilon * db2
    # 更新模型
    model = { 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
    # 一定迭代次数后输出当前误判率
    if print_loss and i % 1000 == 0:
      print("Loss after iteration %i: %f" % (i, calculate_loss(model, X, y)))
  plot_decision_boundary(lambda x: predict(model, x), X, y)
  plt.title("Decision Boundary for hidden layer size %d" % nn_hdim)
  #plt.show()
  return model
def main():
  dataSet, labels = createData(200, 0.20)
  initParameter(dataSet)
  nnModel = build_model(dataSet, labels, 3, print_loss=False)
  print("Loss is %f" % calculate_loss(nnModel, dataSet, labels))
if __name__ == '__main__':
  start = time.clock()
  main()
  end = time.clock()
  print('finish all in %s' % str(end - start))
  plt.show()

输出样例

Loss is 0.071316
finish all in 7.221354361552228

Python实现的NN神经网络算法完整示例

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python网络编程学习笔记(一)
Jun 09 Python
深入解析Python中的WSGI接口
May 11 Python
Python的Scrapy爬虫框架简单学习笔记
Jan 20 Python
Python 闭包的使用方法
Sep 07 Python
Python 和 JS 有哪些相同之处
Nov 23 Python
python实现批量修改图片格式和尺寸
Jun 07 Python
python查看模块,对象的函数方法
Oct 16 Python
在pycharm上mongodb配置及可视化设置方法
Nov 30 Python
在Python中将函数作为另一个函数的参数传入并调用的方法
Jan 22 Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 Python
为什么称python为胶水语言
Jun 16 Python
Pandas 数据编码的十种方法
Apr 20 Python
python中的二维列表实例详解
Jun 19 #Python
Tensorflow中使用tfrecord方式读取数据的方法
Jun 19 #Python
python3实现SMTP发送邮件详细教程
Jun 19 #Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 #Python
Tensorflow使用tfrecord输入数据格式
Jun 19 #Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
You might like
PHP 高级课程笔记 面向对象
2009/06/21 PHP
php页面函数设置超时限制的方法
2014/12/01 PHP
浅谈PHP中pack、unpack的详细用法
2018/03/12 PHP
Javascript 代码也可以变得优美的实现方法
2009/06/22 Javascript
JavaScript中的onerror事件概述及使用
2013/04/01 Javascript
jquery通过a标签删除table中的一行的代码
2013/12/02 Javascript
Javascript和Java获取各种form表单信息的简单实例
2014/02/14 Javascript
Javascript实现苹果悬浮虚拟按钮
2016/04/10 Javascript
Node.js的基本知识简单汇总
2016/09/19 Javascript
炫酷的js手风琴效果
2016/10/13 Javascript
基于JavaScript实现轮播图原理及示例
2020/04/10 Javascript
vue.js实现用户评论、登录、注册、及修改信息功能
2020/05/30 Javascript
JS 中使用Promise 实现红绿灯实例代码(demo)
2017/10/20 Javascript
基于vue-video-player自定义播放器的方法
2018/03/21 Javascript
vue实现简单的MVVM框架
2018/08/05 Javascript
Vue中多元素过渡特效的解决方案
2020/02/05 Javascript
深入理解Antd-Select组件的用法
2020/02/25 Javascript
vue 项目中当访问路由不存在的时候默认访问404页面操作
2020/08/31 Javascript
vue项目中播放rtmp视频文件流的方法
2020/09/17 Javascript
pycharm 使用心得(六)进行简单的数据库管理
2014/06/06 Python
Python实现爬取知乎神回复简单爬虫代码分享
2015/01/04 Python
Python创建xml的方法
2015/03/10 Python
Python浅拷贝与深拷贝用法实例
2015/05/09 Python
Python编程之gui程序实现简单文件浏览器代码
2017/12/08 Python
django ajax json的实例代码
2018/05/29 Python
pandas使用get_dummies进行one-hot编码的方法
2018/07/10 Python
python实现拉普拉斯特征图降维示例
2019/11/25 Python
python 求10个数的平均数实例
2019/12/16 Python
Python基于Webhook实现github自动化部署
2020/11/28 Python
美的官方商城:Midea
2016/09/14 全球购物
The North Face北面荷兰官网:美国著名户外品牌
2019/10/16 全球购物
PHP如何调用MYSQL存储过程
2014/05/30 面试题
应聘医药代表职位求职信
2013/10/21 职场文书
2016年春节问候语
2015/11/11 职场文书
2016年幼儿园庆六一开幕词
2016/03/04 职场文书
nginx刷新页面出现404解决方案(亲测有效)
2022/03/18 Servers