python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python中函数的多种格式和使用实例及小技巧
Apr 13 Python
Python开发SQLite3数据库相关操作详解【连接,查询,插入,更新,删除,关闭等】
Jul 27 Python
Python使用jsonpath-rw模块处理Json对象操作示例
Jul 31 Python
CentOS 7下安装Python3.6 及遇到的问题小结
Nov 08 Python
python3学生名片管理v2.0版
Nov 29 Python
对Python的zip函数妙用,旋转矩阵详解
Dec 13 Python
python判断计算机是否有网络连接的实例
Dec 15 Python
介绍一款python类型检查工具pyright(推荐)
Jul 03 Python
python如何读取bin文件并下发串口
Jul 05 Python
Python递归及尾递归优化操作实例分析
Feb 01 Python
django 前端页面如何实现显示前N条数据
Mar 16 Python
python+selenium+Chrome options参数的使用
Mar 18 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
PHP中根据IP地址判断城市实现城市切换或跳转代码
2012/09/04 PHP
Session的工作机制详解和安全性问题(PHP实例讲解)
2014/04/10 PHP
JavaScript 动态生成方法的例子
2009/07/22 Javascript
textarea中的手动换行处理的jquery代码
2011/02/26 Javascript
读jQuery之七 判断点击了鼠标哪个键的代码
2011/06/21 Javascript
浅析onsubmit校验表单时利用ajax的return false无效问题
2013/07/10 Javascript
浅析Javascript中“==”与“===”的区别
2014/12/23 Javascript
js实现照片墙功能实例
2015/02/05 Javascript
js实现数组冒泡排序、快速排序原理
2016/03/08 Javascript
深入理解JavaScript中的浮点数
2016/05/18 Javascript
JS 调用微信扫一扫功能
2016/12/22 Javascript
原生JS+Canvas实现五子棋游戏实例
2017/06/19 Javascript
angular2 ng build部署后base文件路径问题详细解答
2017/07/15 Javascript
基于Datatables跳转到指定页的简单实例
2017/11/09 Javascript
tween.js缓动补间动画算法示例
2018/02/13 Javascript
基于vue打包后字体和图片资源失效问题的解决方法
2018/03/06 Javascript
AngularJS 前台分页实现的示例代码
2018/06/07 Javascript
用webpack4开发小程序的实现方法
2019/06/04 Javascript
解决vue项目F5刷新mounted里的函数不执行问题
2019/11/05 Javascript
[16:19]教你分分钟做大人——风暴之灵
2015/03/11 DOTA
Pyramid将models.py文件的内容分布到多个文件的方法
2013/11/27 Python
跟老齐学Python之永远强大的函数
2014/09/14 Python
python通过字典dict判断指定键值是否存在的方法
2015/03/21 Python
python实现多进程代码示例
2018/10/31 Python
在Keras中实现保存和加载权重及模型结构
2020/06/15 Python
python中翻译功能translate模块实现方法
2020/12/17 Python
Java程序员常见面试题
2015/07/16 面试题
秘书英文求职信
2014/04/16 职场文书
医院2014国庆节活动策划方案
2014/09/21 职场文书
2014年仓库管理工作总结
2014/12/17 职场文书
工作自我推荐信范文
2015/03/25 职场文书
大学推普周活动总结
2015/05/07 职场文书
欠条格式范本
2015/07/03 职场文书
七夕情人节问候语
2015/11/11 职场文书
大学生村官驻村工作心得体会
2016/01/23 职场文书
python index() 与 rindex() 方法的使用示例详解
2022/12/24 Python