python实现感知机模型的示例


Posted in Python onSeptember 30, 2020
from sklearn.linear_model import Perceptron
import argparse #一个好用的参数传递模型
import numpy as np
from sklearn.datasets import load_iris #数据集
from sklearn.model_selection import train_test_split #训练集和测试集分割
from loguru import logger #日志输出,不清楚用法

#python is also oop 
class PerceptronToby():
  """
  n_epoch:迭代次数
  learning_rate:学习率
  loss_tolerance:损失阈值,即损失函数达到极小值的变化量
  """
  def __init__(self, n_epoch = 500, learning_rate = 0.1, loss_tolerance = 0.01):
    self._n_epoch = n_epoch
    self._lr = learning_rate
    self._loss_tolerance = loss_tolerance
  
  """训练模型,即找到每个数据最合适的权重以得到最小的损失函数"""
  def fit(self, X, y):
    # X:训练集,即数据集,每一行是样本,每一列是数据或标签,一样本包括一数据和一标签
    # y:标签,即1或-1
    n_sample, n_feature = X.shape #剥离矩阵的方法真帅

    #均匀初始化参数
    rnd_val = 1/np.sqrt(n_feature)
    rng = np.random.default_rng()
    self._w = rng.uniform(-rnd_val,rnd_val,size = n_feature)
    #偏置初始化为0
    self._b = 0

    #开始训练了,迭代n_epoch次
    num_epoch = 0 #记录迭代次数
    prev_loss = 0 #前损失值
    while True:
      curr_loss = 0 #现在损失值
      wrong_classify = 0 #误分类样本

      #一次迭代对每个样本操作一次
      for i in range(n_sample):
        #输出函数
        y_pred = np.dot(self._w,X[i]) + self._b
        #损失函数
        curr_loss += -y[i] * y_pred
        # 感知机只对误分类样本进行参数更新,使用梯度下降法
        if y[i] * y_pred <= 0:
          self._w += self._lr * y[i] * X[i]
          self._b += self._lr * y[i]
          wrong_classify += 1

      num_epoch += 1
      loss_diff = curr_loss - prev_loss
      prev_loss = curr_loss
      # 训练终止条件:
      # 1. 训练epoch数达到指定的epoch数时停止训练
      # 2. 本epoch损失与上一个epoch损失差异小于指定的阈值时停止训练
      # 3. 训练过程中不再存在误分类点时停止训练
      if num_epoch >= self._n_epoch or abs(loss_diff) < self._loss_tolerance or wrong_classify == 0:
        break


  """预测模型,顾名思义"""
  def predict(self, x):
    """给定输入样本,预测其类别"""
    y_pred = np.dot(self._w, x) + self._b
    return 1 if y_pred >= 0 else -1

#主函数
def main():
  #参数数组生成
  parser = argparse.ArgumentParser(description="感知机算法实现命令行参数")
  parser.add_argument("--nepoch", type=int, default=500, help="训练多少个epoch后终止训练")
  parser.add_argument("--lr", type=float, default=0.1, help="学习率")
  parser.add_argument("--loss_tolerance", type=float, default=0.001, help="当前损失与上一个epoch损失之差的绝对值小于该值时终止训练")
  args = parser.parse_args()
  #导入数据
  X, y = load_iris(return_X_y=True)
  # print(y)
  y[:50] = -1
  # 分割数据
  xtrain, xtest, ytrain, ytest = train_test_split(X[:100], y[:100], train_size=0.8, shuffle=True)
  # print(xtest)
  #调用并训练模型
  model = PerceptronToby(args.nepoch, args.lr, args.loss_tolerance)
  model.fit(xtrain, ytrain)

  n_test = xtest.shape[0]
  # print(n_test)
  n_right = 0
  for i in range(n_test):
    y_pred = model.predict(xtest[i])
    if y_pred == ytest[i]:
      n_right += 1
    else:
      logger.info("该样本真实标签为:{},但是toby模型预测标签为:{}".format(ytest[i], y_pred))
  logger.info("toby模型在测试集上的准确率为:{}%".format(n_right * 100 / n_test))

  skmodel = Perceptron(max_iter=args.nepoch)
  skmodel.fit(xtrain, ytrain)
  logger.info("sklearn模型在测试集上准确率为:{}%".format(100 * skmodel.score(xtest, ytest)))
if __name__ == "__main__":
  main()```

视频参考地址

以上就是python实现感知机模型的示例的详细内容,更多关于python 实现感知机模型的示例代码的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
tensorflow实现逻辑回归模型
Sep 08 Python
Pycharm更换python解释器的方法
Oct 29 Python
python 实现倒排索引的方法
Dec 25 Python
Python分布式进程中你会遇到的问题解析
May 28 Python
python虚拟环境完美部署教程
Aug 06 Python
python os.path.isfile()因参数问题判断错误的解决
Nov 29 Python
使用TFRecord存取多个数据案例
Feb 17 Python
在django admin中配置搜索域是一个外键时的处理方法
May 20 Python
Python filter()及reduce()函数使用方法解析
Sep 05 Python
详解Python牛顿插值法
May 11 Python
python 实现关联规则算法Apriori的示例
Sep 30 #Python
Python之字典添加元素的几种方法
Sep 30 #Python
Python之字典对象的几种创建方法
Sep 30 #Python
python 实现朴素贝叶斯算法的示例
Sep 30 #Python
Python字典取键、值对的方法步骤
Sep 30 #Python
Python根据字典的值查询出对应的键的方法
Sep 30 #Python
python字典通过值反查键的实现(简洁写法)
Sep 30 #Python
You might like
SONY ICF-F10中波修复记
2021/03/02 无线电
laravel中数据显示方法(默认值和下拉option默认选中)
2019/10/11 PHP
PHP数组实际占用内存大小原理解析
2020/12/11 PHP
jQuery学习笔记(4)--Jquery中获取table中某列值的具体思路
2013/04/10 Javascript
javascript实现促销倒计时+fixed固定在底部
2013/09/18 Javascript
Javascript 拖拽雏形(逐行分析代码,让你轻松了拖拽的原理)
2015/01/23 Javascript
javascript实现类似百度分享功能的方法
2015/07/27 Javascript
判断数组是否包含某个元素的js函数实现方法
2016/05/19 Javascript
EasyUI框架 使用Ajax提交注册信息的实现代码
2017/09/27 Javascript
JS获取当前地理位置的方法
2017/10/25 Javascript
微信小程序实现下拉框功能
2019/07/16 Javascript
微信小程序实现录制、试听、上传音频功能(带波形图)
2020/02/27 Javascript
解决vscode进行vue格式化,会自动补分号和双引号的问题
2020/10/26 Javascript
[01:04:14]OG vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python实现FTP上传文件或文件夹实例(递归)
2017/01/16 Python
Python实现随机生成手机号及正则验证手机号的方法
2018/04/25 Python
Python格式化输出%s和%d
2018/05/07 Python
对python 数据处理中的LabelEncoder 和 OneHotEncoder详解
2018/07/11 Python
Python读取excel中的图片完美解决方法
2018/07/27 Python
python 字典中取值的两种方法小结
2018/08/02 Python
python实现换位加密算法的示例
2018/10/14 Python
破解安装Pycharm的方法
2018/10/19 Python
Python循环实现n的全排列功能
2019/09/16 Python
Python操作redis和mongoDB的方法
2019/12/19 Python
详细分析Python可变对象和不可变对象
2020/07/09 Python
Python-openpyxl表格读取写入的案例详解
2020/11/02 Python
美国百年历史早餐食品供应商:Wolferman’s
2017/01/18 全球购物
阿迪达斯德国官方网站:adidas德国
2017/07/12 全球购物
德国自行车商店:Tretwerk
2019/06/21 全球购物
小学数学国培感言
2014/03/10 职场文书
伦敦奥运会口号
2014/06/13 职场文书
2016天猫双十一广告语
2016/01/28 职场文书
授权协议书范本(3篇)
2019/10/15 职场文书
用python自动生成日历
2021/04/24 Python
使用@Value值注入及配置文件组件扫描
2021/07/09 Java/Android
Python创建SQL数据库流程逐步讲解
2022/09/23 Python