感知器基础原理及python实现过程详解


Posted in Python onSeptember 30, 2019

简单版本,按照李航的《统计学习方法》的思路编写

感知器基础原理及python实现过程详解

数据采用了著名的sklearn自带的iries数据,最优化求解采用了SGD算法。

预处理增加了标准化操作。

'''
perceptron classifier

created on 2019.9.14
author: vince
'''
import pandas 
import numpy 
import logging
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

'''
perceptron classifier

Attributes
w: ld-array = weights after training
l: list = number of misclassification during each iteration 
'''
class Perceptron:
  def __init__(self, eta = 0.01, iter_num = 50, batch_size = 1):
    '''
    eta: float = learning rate (between 0.0 and 1.0).
    iter_num: int = iteration over the training dataset.
    batch_size: int = gradient descent batch number, 
      if batch_size == 1, used SGD; 
      if batch_size == 0, use BGD; 
      else MBGD;
    '''

    self.eta = eta;
    self.iter_num = iter_num;
    self.batch_size = batch_size;

  def train(self, X, Y):
    '''
    train training data.
    X:{array-like}, shape=[n_samples, n_features] = Training vectors, 
      where n_samples is the number of training samples and 
      n_features is the number of features.
    Y:{array-like}, share=[n_samples] = traget values.
    '''
    self.w = numpy.zeros(1 + X.shape[1]);
    self.l = numpy.zeros(self.iter_num);
    for iter_index in range(self.iter_num):
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          logging.debug("%s: pred(%s), label(%s), %s, %s" % (sample_index, 
            self.net_input(X[sample_index]) , Y[sample_index],
            X[sample_index, 0], X[sample_index, 1]));
          self.l[iter_index] += 1;
      for sample_index in range(X.shape[0]): 
        if (self.activation(X[sample_index]) != Y[sample_index]):
          self.w[0] += self.eta * Y[sample_index];
          self.w[1:] += self.eta * numpy.dot(X[sample_index], Y[sample_index]);
          break;
      logging.info("iter %s: %s, %s, %s, %s" %
          (iter_index, self.w[0], self.w[1], self.w[2], self.l[iter_index]));

  def activation(self, x):
    return numpy.where(self.net_input(x) >= 0.0 , 1 , -1);

  def net_input(self, x): 
    return numpy.dot(x, self.w[1:]) + self.w[0];

  def predict(self, x):
    return self.activation(x);

def main():
  logging.basicConfig(level = logging.INFO,
      format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
      datefmt = '%a, %d %b %Y %H:%M:%S');

  iris = load_iris();

  features = iris.data[:99, [0, 2]];
  # normalization
  features_std = numpy.copy(features);
  for i in range(features.shape[1]):
    features_std[:, i] = (features_std[:, i] - features[:, i].mean()) / features[:, i].std();

  labels = numpy.where(iris.target[:99] == 0, -1, 1);

  # 2/3 data from training, 1/3 data for testing
  train_features, test_features, train_labels, test_labels = train_test_split(
      features_std, labels, test_size = 0.33, random_state = 23323);
  
  logging.info("train set shape:%s" % (str(train_features.shape)));

  p = Perceptron();

  p.train(train_features, train_labels);
    
  test_predict = numpy.array([]);
  for feature in test_features:
    predict_label = p.predict(feature);
    test_predict = numpy.append(test_predict, predict_label);

  score = accuracy_score(test_labels, test_predict);
  logging.info("The accruacy score is: %s "% (str(score)));

  #plot
  x_min, x_max = train_features[:, 0].min() - 1, train_features[:, 0].max() + 1;
  y_min, y_max = train_features[:, 1].min() - 1, train_features[:, 1].max() + 1;
  plt.xlim(x_min, x_max);
  plt.ylim(y_min, y_max);
  plt.xlabel("width");
  plt.ylabel("heigt");

  plt.scatter(train_features[:, 0], train_features[:, 1], c = train_labels, marker = 'o', s = 10);

  k = - p.w[1] / p.w[2];
  d = - p.w[0] / p.w[2];

  plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-");

  plt.show();
  

if __name__ == "__main__":
  main();

感知器基础原理及python实现过程详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现插入排序算法
Jun 05 Python
python简单程序读取串口信息的方法
Mar 13 Python
python实现定时播放mp3
Mar 29 Python
python针对excel的操作技巧
Mar 13 Python
对Python 网络设备巡检脚本的实例讲解
Apr 22 Python
Python实现的远程登录windows系统功能示例
Jun 21 Python
python format 格式化输出方法
Jul 16 Python
python中使用ctypes调用so传参设置遇到的问题及解决方法
Jun 19 Python
python实现动态创建类的方法分析
Jun 25 Python
浅析pip安装第三方库及pycharm中导入第三方库的问题
Mar 10 Python
jupyter 中文乱码设置编码格式 避免控制台输出的解决
Apr 20 Python
Django分页器的用法你都了解吗
May 26 Python
基于python的BP神经网络及异或实现过程解析
Sep 30 #Python
Window10下python3.7 安装与卸载教程图解
Sep 30 #Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 #Python
Python3 合并二叉树的实现
Sep 30 #Python
自适应线性神经网络Adaline的python实现详解
Sep 30 #Python
softmax及python实现过程解析
Sep 30 #Python
python根据时间获取周数代码实例
Sep 30 #Python
You might like
php生成excel文件的简单方法
2014/02/08 PHP
thinkphp实现发送邮件密码找回功能实例
2014/12/01 PHP
详解WordPress中用于合成数组的wp_parse_args()函数
2015/12/18 PHP
jQuery 1.0.4 - New Wave Javascript(js源文件)
2007/01/15 Javascript
js动画效果制件让图片组成动画代码分享
2014/01/14 Javascript
JavaScript正则表达式匹配 div  style标签
2016/03/15 Javascript
为什么JavaScript没有块级作用域
2016/05/22 Javascript
jQuery layui常用方法介绍
2016/07/25 Javascript
最好用的Bootstrap fileinput.js文件上传组件
2016/12/12 Javascript
angular实现表单验证及提交功能
2017/02/01 Javascript
jQuery时间验证和转换为标准格式的时间格式
2017/03/06 Javascript
将angular.js项目整合到.net mvc中的方法详解
2017/06/29 Javascript
JS+canvas画一个圆锥实例代码
2017/12/13 Javascript
前端MVVM框架解析之双向绑定
2018/01/24 Javascript
vue2.0 循环遍历加载不同图片的方法
2018/03/06 Javascript
Vue函数式组件的应用实例详解
2019/08/30 Javascript
vue中echarts的用法及与elementui-select的协同绑定操作
2020/11/17 Vue.js
[03:38]2014DOTA2西雅图国际邀请赛 VG战队巡礼
2014/07/07 DOTA
[00:49]完美世界DOTA2联赛10月28日开团时刻:随便打
2020/10/29 DOTA
Python和GO语言实现的消息摘要算法示例
2015/03/10 Python
详解python之配置日志的几种方式
2017/05/22 Python
python使用装饰器作日志处理的方法
2019/07/11 Python
jupyter notebook 使用过程中python莫名崩溃的原因及解决方式
2020/04/10 Python
使用OpenCV获取图像某点的颜色值,并设置某点的颜色
2020/06/02 Python
Pycharm编辑器功能之代码折叠效果的实现代码
2020/10/15 Python
Nike西班牙官方网站:Nike.com (ES)
2017/10/30 全球购物
美国家居装饰网上商店:Lulu & Georgia
2019/09/14 全球购物
美国户外服装和装备购物网站:Outland USA
2020/03/22 全球购物
全球最受追捧的运动服品牌领先数字目的地:Stylerunner
2020/11/25 全球购物
赔偿协议书范本
2014/04/15 职场文书
公司募捐倡议书
2014/05/14 职场文书
运动会报道稿300字
2014/10/02 职场文书
中学生检讨书范文
2014/11/03 职场文书
2015年乡镇安全生产工作总结
2015/05/19 职场文书
Mysql服务添加 iptables防火墙策略的方案
2021/04/29 MySQL
Python中npy和mat文件的保存与读取
2022/04/24 Python