感知器基础原理及python实现过程详解


Posted in Python onSeptember 30, 2019

简单版本,按照李航的《统计学习方法》的思路编写

感知器基础原理及python实现过程详解

数据采用了著名的sklearn自带的iries数据,最优化求解采用了SGD算法。

预处理增加了标准化操作。

'''
perceptron classifier

created on 2019.9.14
author: vince
'''
import pandas 
import numpy 
import logging
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

'''
perceptron classifier

Attributes
w: ld-array = weights after training
l: list = number of misclassification during each iteration 
'''
class Perceptron:
  def __init__(self, eta = 0.01, iter_num = 50, batch_size = 1):
    '''
    eta: float = learning rate (between 0.0 and 1.0).
    iter_num: int = iteration over the training dataset.
    batch_size: int = gradient descent batch number, 
      if batch_size == 1, used SGD; 
      if batch_size == 0, use BGD; 
      else MBGD;
    '''

    self.eta = eta;
    self.iter_num = iter_num;
    self.batch_size = batch_size;

  def train(self, X, Y):
    '''
    train training data.
    X:{array-like}, shape=[n_samples, n_features] = Training vectors, 
      where n_samples is the number of training samples and 
      n_features is the number of features.
    Y:{array-like}, share=[n_samples] = traget values.
    '''
    self.w = numpy.zeros(1 + X.shape[1]);
    self.l = numpy.zeros(self.iter_num);
    for iter_index in range(self.iter_num):
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          logging.debug("%s: pred(%s), label(%s), %s, %s" % (sample_index, 
            self.net_input(X[sample_index]) , Y[sample_index],
            X[sample_index, 0], X[sample_index, 1]));
          self.l[iter_index] += 1;
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          self.w[0] += self.eta * Y[sample_index];
          self.w[1:] += self.eta * numpy.dot(X[sample_index], Y[sample_index]);
          break;
      logging.info("iter %s: %s, %s, %s, %s" %
          (iter_index, self.w[0], self.w[1], self.w[2], self.l[iter_index]));

  def activation(self, x):
    return numpy.where(self.net_input(x) >= 0.0 , 1 , -1);

  def net_input(self, x): 
    return numpy.dot(x, self.w[1:]) + self.w[0];

  def predict(self, x):
    return self.activation(x);

def main():
  logging.basicConfig(level = logging.INFO,
      format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
      datefmt = '%a, %d %b %Y %H:%M:%S');

  iris = load_iris();

  features = iris.data[:99, [0, 2]];
  # normalization
  features_std = numpy.copy(features);
  for i in range(features.shape[1]):
    features_std[:, i] = (features_std[:, i] - features[:, i].mean()) / features[:, i].std();

  labels = numpy.where(iris.target[:99] == 0, -1, 1);

  # 2/3 data from training, 1/3 data for testing
  train_features, test_features, train_labels, test_labels = train_test_split(
      features_std, labels, test_size = 0.33, random_state = 23323);
  
  logging.info("train set shape:%s" % (str(train_features.shape)));

  p = Perceptron();

  p.train(train_features, train_labels);
    
  test_predict = numpy.array([]);
  for feature in test_features:
    predict_label = p.predict(feature);
    test_predict = numpy.append(test_predict, predict_label);

  score = accuracy_score(test_labels, test_predict);
  logging.info("The accruacy score is: %s "% (str(score)));

  #plot
  x_min, x_max = train_features[:, 0].min() - 1, train_features[:, 0].max() + 1;
  y_min, y_max = train_features[:, 1].min() - 1, train_features[:, 1].max() + 1;
  plt.xlim(x_min, x_max);
  plt.ylim(y_min, y_max);
  plt.xlabel("width");
  plt.ylabel("heigt");

  plt.scatter(train_features[:, 0], train_features[:, 1], c = train_labels, marker = 'o', s = 10);

  k = - p.w[1] / p.w[2];
  d = - p.w[0] / p.w[2];

  plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-");

  plt.show();
  

if __name__ == "__main__":
  main();

感知器基础原理及python实现过程详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
几个提升Python运行效率的方法之间的对比
Apr 03 Python
Python基于Socket实现的简单聊天程序示例
Aug 05 Python
Python 类的特殊成员解析
Jun 20 Python
python 列表降维的实例讲解
Jun 28 Python
python使用插值法画出平滑曲线
Dec 15 Python
pandas去重复行并分类汇总的实现方法
Jan 29 Python
django 2.2和mysql使用的常见问题
Jul 18 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
Nov 07 Python
基于python cut和qcut的用法及区别详解
Nov 22 Python
Python模块 _winreg操作注册表
Feb 05 Python
python实现俄罗斯方块游戏(改进版)
Mar 13 Python
python用海龟绘图写贪吃蛇游戏
Jun 18 Python
基于python的BP神经网络及异或实现过程解析
Sep 30 #Python
Window10下python3.7 安装与卸载教程图解
Sep 30 #Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 #Python
Python3 合并二叉树的实现
Sep 30 #Python
自适应线性神经网络Adaline的python实现详解
Sep 30 #Python
softmax及python实现过程解析
Sep 30 #Python
python根据时间获取周数代码实例
Sep 30 #Python
You might like
用mysql_fetch_array()获取当前行数据的方法详解
2013/06/05 PHP
AJAX的跨域访问-两种有效的解决方法介绍
2013/06/22 PHP
通过Email发送PHP错误的方法
2015/07/20 PHP
Yii2框架数据验证操作实例详解
2018/05/02 PHP
IE浏览器打印的页眉页脚设置解决方法
2009/12/08 Javascript
asp.net+jquery滚动滚动条加载数据的下拉控件
2010/06/25 Javascript
js对象的比较
2011/02/26 Javascript
用JS判别浏览器种类以及IE版本的几种方法小结
2011/08/02 Javascript
分享10个优化代码的CSS和JavaScript工具
2016/05/11 Javascript
静态页面html中跳转传值的JS处理技巧
2016/06/22 Javascript
JS 获取HTML标签内的子节点的方法
2016/09/21 Javascript
基于BootStrap栅格栏系统完成网站底部版权信息区
2016/12/23 Javascript
解析JavaScript实现DDoS攻击原理与保护措施
2016/12/26 Javascript
jQuery模拟实现天猫购物车动画效果实例代码
2017/05/25 jQuery
jQuery开源组件BootstrapValidator使用详解
2017/06/29 jQuery
详解如何提高 webpack 构建 Vue 项目的速度
2017/07/03 Javascript
vue axios数据请求get、post方法及实例详解
2018/09/11 Javascript
python+matplotlib实现鼠标移动三角形高亮及索引显示
2018/01/15 Python
Django学习之文件上传与下载
2019/10/06 Python
python实现的读取网页并分词功能示例
2019/10/29 Python
Python类的动态绑定实现原理
2020/03/21 Python
keras小技巧——获取某一个网络层的输出方式
2020/05/23 Python
Django DRF路由与扩展功能的实现
2020/06/03 Python
python 使用paramiko模块进行封装,远程操作linux主机的示例代码
2020/12/03 Python
鞋类设计与工艺专业销售求职信
2013/11/01 职场文书
ktv中秋节活动方案
2014/01/30 职场文书
大学课外活动总结
2014/07/09 职场文书
学习实践科学发展观心得体会
2014/09/10 职场文书
三行辞职书范文
2015/02/26 职场文书
小学重阳节活动总结
2015/03/24 职场文书
小学安全工作总结2015
2015/05/18 职场文书
超市啤酒狂欢夜策划方案范文!
2019/07/03 职场文书
解决Python字典查找报Keyerror的问题
2021/05/26 Python
OpenCV实现普通阈值
2021/11/17 Java/Android
js中Map和Set的用法及区别实例详解
2022/02/15 Javascript
vue3种table表格选项个数的控制方法
2022/04/14 Vue.js