感知器基础原理及python实现过程详解


Posted in Python onSeptember 30, 2019

简单版本,按照李航的《统计学习方法》的思路编写

感知器基础原理及python实现过程详解

数据采用了著名的sklearn自带的iries数据,最优化求解采用了SGD算法。

预处理增加了标准化操作。

'''
perceptron classifier

created on 2019.9.14
author: vince
'''
import pandas 
import numpy 
import logging
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

'''
perceptron classifier

Attributes
w: ld-array = weights after training
l: list = number of misclassification during each iteration 
'''
class Perceptron:
  def __init__(self, eta = 0.01, iter_num = 50, batch_size = 1):
    '''
    eta: float = learning rate (between 0.0 and 1.0).
    iter_num: int = iteration over the training dataset.
    batch_size: int = gradient descent batch number, 
      if batch_size == 1, used SGD; 
      if batch_size == 0, use BGD; 
      else MBGD;
    '''

    self.eta = eta;
    self.iter_num = iter_num;
    self.batch_size = batch_size;

  def train(self, X, Y):
    '''
    train training data.
    X:{array-like}, shape=[n_samples, n_features] = Training vectors, 
      where n_samples is the number of training samples and 
      n_features is the number of features.
    Y:{array-like}, share=[n_samples] = traget values.
    '''
    self.w = numpy.zeros(1 + X.shape[1]);
    self.l = numpy.zeros(self.iter_num);
    for iter_index in range(self.iter_num):
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          logging.debug("%s: pred(%s), label(%s), %s, %s" % (sample_index, 
            self.net_input(X[sample_index]) , Y[sample_index],
            X[sample_index, 0], X[sample_index, 1]));
          self.l[iter_index] += 1;
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          self.w[0] += self.eta * Y[sample_index];
          self.w[1:] += self.eta * numpy.dot(X[sample_index], Y[sample_index]);
          break;
      logging.info("iter %s: %s, %s, %s, %s" %
          (iter_index, self.w[0], self.w[1], self.w[2], self.l[iter_index]));

  def activation(self, x):
    return numpy.where(self.net_input(x) >= 0.0 , 1 , -1);

  def net_input(self, x): 
    return numpy.dot(x, self.w[1:]) + self.w[0];

  def predict(self, x):
    return self.activation(x);

def main():
  logging.basicConfig(level = logging.INFO,
      format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
      datefmt = '%a, %d %b %Y %H:%M:%S');

  iris = load_iris();

  features = iris.data[:99, [0, 2]];
  # normalization
  features_std = numpy.copy(features);
  for i in range(features.shape[1]):
    features_std[:, i] = (features_std[:, i] - features[:, i].mean()) / features[:, i].std();

  labels = numpy.where(iris.target[:99] == 0, -1, 1);

  # 2/3 data from training, 1/3 data for testing
  train_features, test_features, train_labels, test_labels = train_test_split(
      features_std, labels, test_size = 0.33, random_state = 23323);
  
  logging.info("train set shape:%s" % (str(train_features.shape)));

  p = Perceptron();

  p.train(train_features, train_labels);
    
  test_predict = numpy.array([]);
  for feature in test_features:
    predict_label = p.predict(feature);
    test_predict = numpy.append(test_predict, predict_label);

  score = accuracy_score(test_labels, test_predict);
  logging.info("The accruacy score is: %s "% (str(score)));

  #plot
  x_min, x_max = train_features[:, 0].min() - 1, train_features[:, 0].max() + 1;
  y_min, y_max = train_features[:, 1].min() - 1, train_features[:, 1].max() + 1;
  plt.xlim(x_min, x_max);
  plt.ylim(y_min, y_max);
  plt.xlabel("width");
  plt.ylabel("heigt");

  plt.scatter(train_features[:, 0], train_features[:, 1], c = train_labels, marker = 'o', s = 10);

  k = - p.w[1] / p.w[2];
  d = - p.w[0] / p.w[2];

  plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-");

  plt.show();
  

if __name__ == "__main__":
  main();

感知器基础原理及python实现过程详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例详解Python中的split()函数的使用方法
Apr 07 Python
Python使用cookielib模块操作cookie的实例教程
Jul 12 Python
python 异常处理总结
Oct 18 Python
python简单图片操作:打开\显示\保存图像方法介绍
Nov 23 Python
python实现图片批量压缩程序
Jul 23 Python
python七夕浪漫表白源码
Apr 05 Python
详解python中递归函数
Apr 16 Python
python图像和办公文档处理总结
May 28 Python
Python几种常见算法汇总
Jun 02 Python
Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)
Jul 07 Python
python raise的基本使用
Sep 10 Python
python中requests库+xpath+lxml简单使用
Apr 29 Python
基于python的BP神经网络及异或实现过程解析
Sep 30 #Python
Window10下python3.7 安装与卸载教程图解
Sep 30 #Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 #Python
Python3 合并二叉树的实现
Sep 30 #Python
自适应线性神经网络Adaline的python实现详解
Sep 30 #Python
softmax及python实现过程解析
Sep 30 #Python
python根据时间获取周数代码实例
Sep 30 #Python
You might like
php fsockopen中多线程问题的解决办法[翻译]
2011/11/09 PHP
php中return的用法实例分析
2015/02/28 PHP
PHP实现过滤各种HTML标签
2015/05/17 PHP
php抓取并保存网站图片的实现代码
2015/10/28 PHP
RSA实现JS前端加密与PHP后端解密功能示例
2019/08/05 PHP
PHP实现页面静态化深入讲解
2021/03/04 PHP
一个选择最快的服务器转向代码
2009/04/27 Javascript
js实现运动logo图片效果及运动元素对象sportBox使用方法
2012/12/25 Javascript
利用jq让你的div居中的好方法分享
2013/11/21 Javascript
jQuery添加和删除指定标签的方法
2015/12/16 Javascript
微信小程序 生命周期详解
2016/10/12 Javascript
基于jQuery实现的单行公告活动轮播效果
2017/08/23 jQuery
全新打包工具parcel零配置vue开发脚手架
2018/01/11 Javascript
修改Nodejs内置的npm默认配置路径方法
2018/05/13 NodeJs
JS实现移动端点击按钮复制文本内容
2019/07/28 Javascript
微信小程序收藏功能的实现代码
2020/06/19 Javascript
nuxt.js服务端渲染中axios和proxy代理的配置操作
2020/11/06 Javascript
JavaScript实现表单验证功能
2020/12/09 Javascript
python实现识别相似图片小结
2016/02/22 Python
Python学习小技巧之利用字典的默认行为
2017/05/20 Python
Python多进程multiprocessing用法实例分析
2017/08/18 Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
2017/12/14 Python
python远程连接服务器MySQL数据库
2018/07/02 Python
windows10下安装TensorFlow Object Detection API的步骤
2019/06/13 Python
pandas的resample重采样的使用
2020/04/24 Python
Python 找出英文单词列表(list)中最长单词链
2020/12/14 Python
CSS3 特效范例整理
2011/08/22 HTML / CSS
美国儿童运动鞋和服装零售商:Kids Foot Locker
2017/08/05 全球购物
班班通项目实施方案
2014/02/25 职场文书
管理标语大全
2014/06/24 职场文书
节能环保家庭事迹材料
2014/08/27 职场文书
公务员四风问题对照检查材料整改措施
2014/09/26 职场文书
销售经理工作失职检讨书
2014/10/24 职场文书
十八大观后感
2015/06/12 职场文书
PostgreSQL数据库去除重复数据和运算符的基本查询操作
2022/04/12 PostgreSQL
Android 中的类文件和类加载器详情
2022/06/05 Java/Android