python实现感知机模型的示例


Posted in Python onSeptember 30, 2020
from sklearn.linear_model import Perceptron
import argparse #一个好用的参数传递模型
import numpy as np
from sklearn.datasets import load_iris #数据集
from sklearn.model_selection import train_test_split #训练集和测试集分割
from loguru import logger #日志输出,不清楚用法

#python is also oop 
class PerceptronToby():
  """
  n_epoch:迭代次数
  learning_rate:学习率
  loss_tolerance:损失阈值,即损失函数达到极小值的变化量
  """
  def __init__(self, n_epoch = 500, learning_rate = 0.1, loss_tolerance = 0.01):
    self._n_epoch = n_epoch
    self._lr = learning_rate
    self._loss_tolerance = loss_tolerance
  
  """训练模型,即找到每个数据最合适的权重以得到最小的损失函数"""
  def fit(self, X, y):
    # X:训练集,即数据集,每一行是样本,每一列是数据或标签,一样本包括一数据和一标签
    # y:标签,即1或-1
    n_sample, n_feature = X.shape #剥离矩阵的方法真帅

    #均匀初始化参数
    rnd_val = 1/np.sqrt(n_feature)
    rng = np.random.default_rng()
    self._w = rng.uniform(-rnd_val,rnd_val,size = n_feature)
    #偏置初始化为0
    self._b = 0

    #开始训练了,迭代n_epoch次
    num_epoch = 0 #记录迭代次数
    prev_loss = 0 #前损失值
    while True:
      curr_loss = 0 #现在损失值
      wrong_classify = 0 #误分类样本

      #一次迭代对每个样本操作一次
      for i in range(n_sample):
        #输出函数
        y_pred = np.dot(self._w,X[i]) + self._b
        #损失函数
        curr_loss += -y[i] * y_pred
        # 感知机只对误分类样本进行参数更新,使用梯度下降法
        if y[i] * y_pred <= 0:
          self._w += self._lr * y[i] * X[i]
          self._b += self._lr * y[i]
          wrong_classify += 1

      num_epoch += 1
      loss_diff = curr_loss - prev_loss
      prev_loss = curr_loss
      # 训练终止条件:
      # 1. 训练epoch数达到指定的epoch数时停止训练
      # 2. 本epoch损失与上一个epoch损失差异小于指定的阈值时停止训练
      # 3. 训练过程中不再存在误分类点时停止训练
      if num_epoch >= self._n_epoch or abs(loss_diff) < self._loss_tolerance or wrong_classify == 0:
        break


  """预测模型,顾名思义"""
  def predict(self, x):
    """给定输入样本,预测其类别"""
    y_pred = np.dot(self._w, x) + self._b
    return 1 if y_pred >= 0 else -1

#主函数
def main():
  #参数数组生成
  parser = argparse.ArgumentParser(description="感知机算法实现命令行参数")
  parser.add_argument("--nepoch", type=int, default=500, help="训练多少个epoch后终止训练")
  parser.add_argument("--lr", type=float, default=0.1, help="学习率")
  parser.add_argument("--loss_tolerance", type=float, default=0.001, help="当前损失与上一个epoch损失之差的绝对值小于该值时终止训练")
  args = parser.parse_args()
  #导入数据
  X, y = load_iris(return_X_y=True)
  # print(y)
  y[:50] = -1
  # 分割数据
  xtrain, xtest, ytrain, ytest = train_test_split(X[:100], y[:100], train_size=0.8, shuffle=True)
  # print(xtest)
  #调用并训练模型
  model = PerceptronToby(args.nepoch, args.lr, args.loss_tolerance)
  model.fit(xtrain, ytrain)

  n_test = xtest.shape[0]
  # print(n_test)
  n_right = 0
  for i in range(n_test):
    y_pred = model.predict(xtest[i])
    if y_pred == ytest[i]:
      n_right += 1
    else:
      logger.info("该样本真实标签为:{},但是toby模型预测标签为:{}".format(ytest[i], y_pred))
  logger.info("toby模型在测试集上的准确率为:{}%".format(n_right * 100 / n_test))

  skmodel = Perceptron(max_iter=args.nepoch)
  skmodel.fit(xtrain, ytrain)
  logger.info("sklearn模型在测试集上准确率为:{}%".format(100 * skmodel.score(xtest, ytest)))
if __name__ == "__main__":
  main()```

视频参考地址

以上就是python实现感知机模型的示例的详细内容,更多关于python 实现感知机模型的示例代码的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
在Python的Django框架中创建语言文件
Jul 27 Python
Python的净值数据接口调用示例分享
Mar 15 Python
Python的SQLalchemy模块连接与操作MySQL的基础示例
Jul 11 Python
python实现Virginia无密钥解密
Mar 20 Python
深度辨析Python的eval()与exec()的方法
Mar 26 Python
python的内存管理和垃圾回收机制详解
May 18 Python
深入浅析python的第三方库pandas
Feb 13 Python
Django 返回json数据的实现示例
Mar 05 Python
python IDLE添加行号显示教程
Apr 25 Python
Python使用pyexecjs代码案例解析
Jul 13 Python
Python如何在单元测试中给对象打补丁
Aug 03 Python
用Python爬虫破解滑动验证码的案例解析
May 06 Python
python 实现关联规则算法Apriori的示例
Sep 30 #Python
Python之字典添加元素的几种方法
Sep 30 #Python
Python之字典对象的几种创建方法
Sep 30 #Python
python 实现朴素贝叶斯算法的示例
Sep 30 #Python
Python字典取键、值对的方法步骤
Sep 30 #Python
Python根据字典的值查询出对应的键的方法
Sep 30 #Python
python字典通过值反查键的实现(简洁写法)
Sep 30 #Python
You might like
PHP 得到根目录的 __FILE__ 常量
2008/07/23 PHP
理解php原理的opcodes(操作码)
2010/10/26 PHP
PHP+Memcache实现wordpress访问总数统计(非插件)
2014/07/04 PHP
thinkphp特殊标签用法概述
2014/11/24 PHP
PHP设置进度条的方法
2015/07/08 PHP
Thinkphp无限级分类代码
2015/11/11 PHP
ThinkPHP中order()使用方法详解
2016/04/19 PHP
php实现有序数组旋转后寻找最小值方法
2018/09/27 PHP
php源码的使用方法讲解
2019/09/26 PHP
tp5.1 框架路由操作-URL生成实例分析
2020/05/26 PHP
Mootools 1.2教程 设置和获取样式表属性
2009/09/15 Javascript
extJs 常用到的增,删,改,查操作代码
2009/12/28 Javascript
jQuery+ajax实现动态执行脚本的方法
2015/01/27 Javascript
jQuery侧边栏实现代码
2016/05/06 Javascript
JavaScript中有关一个数组中最大值和最小值及它们的下表的输出的解决办法
2016/07/01 Javascript
javascript中的 object 和 function小结
2016/08/14 Javascript
web前端vue之vuex单独一文件使用方式实例详解
2018/01/11 Javascript
微信小程序实现自定义modal弹窗封装的方法
2018/06/15 Javascript
vue-cli 2.*中导入公共less文件的方法步骤
2018/11/22 Javascript
详解Vue 全局变量,局部变量
2019/04/17 Javascript
nodejs 递归拷贝、读取目录下所有文件和目录
2019/07/18 NodeJs
es6函数之严格模式用法实例分析
2020/03/17 Javascript
《javascript设计模式》学习笔记五:Javascript面向对象程序设计工厂模式实例分析
2020/04/08 Javascript
VUE页面中通过双击实现复制表格中内容的示例代码
2020/06/11 Javascript
Python的内存泄漏及gc模块的使用分析
2014/07/16 Python
Python中的asyncio代码详解
2019/06/10 Python
浅谈numpy中np.array()与np.asarray的区别以及.tolist
2020/06/03 Python
HTML5 Canvas锯齿图代码实例
2014/04/10 HTML / CSS
计算机大学生职业生涯规划书范文
2014/02/19 职场文书
征兵宣传标语
2014/06/20 职场文书
晚自修旷课检讨书怎么写
2014/11/17 职场文书
团代会闭幕词
2015/01/28 职场文书
银行实习推荐信
2015/03/27 职场文书
2016年社区服务活动总结
2016/04/06 职场文书
vue使用v-model进行跨组件绑定的基本实现方法
2021/04/28 Vue.js
win server2012 r2服务器共享文件夹如何设置
2022/06/21 Servers