python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python中Django框架利用url来控制登录的方法
Jul 25 Python
Python打包文件夹的方法小结(zip,tar,tar.gz等)
Sep 18 Python
Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)
Sep 21 Python
python入门基础之用户输入与模块初认识
Nov 14 Python
python+POP3实现批量下载邮件附件
Jun 19 Python
在Python文件中指定Python解释器的方法
Feb 18 Python
opencv实现简单人脸识别
Feb 19 Python
解决Pycharm 导入其他文件夹源码的2种方法
Feb 12 Python
浅谈优化Django ORM中的性能问题
Jul 09 Python
Django REST 异常处理详解
Jul 15 Python
Pygame如何使用精灵和碰撞检测
Nov 17 Python
Python+pyaudio实现音频控制示例详解
Jul 23 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
用PHP的超级变量$_GET获取HTML表单(Form) 数据
2011/05/07 PHP
JavaScript创建命名空间的5种写法
2014/06/24 PHP
实例讲解PHP面向对象之多态
2014/08/20 PHP
PHP图像处理之imagecreate、imagedestroy函数介绍
2014/11/19 PHP
PHP定时执行任务实现方法详解(Timer)
2015/07/30 PHP
WordPress主题制作之模板文件的引入方法
2015/12/28 PHP
PHP 生成微信红包代码简单
2016/03/25 PHP
一个无限级XML绑定跨框架菜单(For IE)
2007/01/27 Javascript
用jQuery模拟页面加载进度条的实现代码
2011/12/19 Javascript
jQuery Deferred和Promise创建响应式应用程序详细介绍
2013/03/05 Javascript
ExtJS如何设置与获取radio控件的选取状态
2014/01/22 Javascript
jQuery对下拉框,单选框,多选框的操作
2014/02/21 Javascript
JavaScript中定义函数的三种方法
2015/03/12 Javascript
jQuery Form 表单提交插件之formSerialize,fieldSerialize,fieldValue,resetForm,clearForm,clearFields的应用
2016/01/23 Javascript
对Js OOP编程 创建对象的一些全面理解
2016/07/26 Javascript
angularJs使用$watch和$filter过滤器制作搜索筛选实例
2017/06/01 Javascript
Vue非父子组件通信详解
2017/06/12 Javascript
JS设计模式之数据访问对象模式的实例讲解
2017/09/30 Javascript
angular2/ionic2 实现搜索结果中的搜索关键字高亮的示例
2018/08/17 Javascript
详解Node.js读写中文内容文件操作
2018/10/10 Javascript
Angular设置别名alias的方法
2018/11/08 Javascript
Python实现分割文件及合并文件的方法
2015/07/10 Python
在Django的模型中执行原始SQL查询的方法
2015/07/21 Python
Python通过matplotlib画双层饼图及环形图简单示例
2017/12/15 Python
python3使用smtplib实现发送邮件功能
2018/05/22 Python
Python 一键获取百度网盘提取码的方法
2019/08/01 Python
python中栈的原理及实现方法示例
2019/11/27 Python
python开根号实例讲解
2020/08/30 Python
纯CSS3实现的8种Loading动画效果
2014/07/05 HTML / CSS
创建省级文明单位实施方案
2014/02/27 职场文书
机关干部三严三实心得体会
2014/10/13 职场文书
三提三创主题教育活动查摆整改措施
2014/10/25 职场文书
党的群众路线教育实践活动制度建设计划
2014/11/03 职场文书
2015年新学期寄语
2015/02/26 职场文书
三八节活动简报
2015/07/20 职场文书
Pyhton爬虫知识之正则表达式详解
2022/04/01 Python