python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python中sets模块的用法实例
Sep 30 Python
python执行子进程实现进程间通信的方法
Jun 02 Python
python控制台中实现进度条功能
Nov 10 Python
Python pymongo模块用法示例
Mar 31 Python
pytorch 调整某一维度数据顺序的方法
Dec 08 Python
pyqt5利用pyqtDesigner实现登录界面
Mar 28 Python
python打开使用的方法
Sep 30 Python
Python实现报警信息实时发送至邮箱功能(实例代码)
Nov 11 Python
Python使用matplotlib绘制Logistic曲线操作示例
Nov 28 Python
如何使用python socket模块实现简单的文件下载
Sep 04 Python
python中类与对象之间的关系详解
Dec 16 Python
python爬取豆瓣电影排行榜(requests)的示例代码
Feb 18 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
shopex中集成的站长统计功能的代码简单分析
2011/08/11 PHP
php导入csv文件碰到乱码问题的解决方法
2014/02/10 PHP
PHP对象递归引用造成内存泄漏分析
2014/08/28 PHP
PHP实现获取ip地址的5种方法,以及插入用户登录日志操作示例
2019/02/28 PHP
laravel通用化的CURD的实现
2019/12/13 PHP
Javascript实例教程(19) 使用HoTMetal(5)
2006/12/23 Javascript
Jquery实战_读书笔记1—选择jQuery
2010/01/22 Javascript
理解Javascript_06_理解对象的创建过程
2010/10/15 Javascript
JavaScript 注册事件代码
2011/01/27 Javascript
jquery果冻抖动效果实现方法
2015/01/15 Javascript
js实现在网页上简单显示时间的方法
2015/03/02 Javascript
JS判断网页广告是否被浏览器拦截过滤的代码
2015/04/05 Javascript
jQuery实现仿路边灯箱广告图片轮播效果
2015/04/15 Javascript
js实现具有高亮显示效果的多级菜单代码
2015/09/01 Javascript
微信小程序 教程之WXSS
2016/10/18 Javascript
Bootstrap导航条学习使用(二)
2017/02/08 Javascript
整理关于Bootstrap警示框的慕课笔记
2017/03/29 Javascript
js轮播图透明度切换(带上下页和底部圆点切换)
2017/04/27 Javascript
详解vue服务端渲染(SSR)初探
2017/06/19 Javascript
Angular js 实现添加用户、修改密码、敏感字、下拉菜单的综合操作方法
2017/10/24 Javascript
javascript原生封装一个淡入淡出效果的函数测试实例代码
2018/03/19 Javascript
vue.js编译时给生成的文件增加版本号
2018/09/17 Javascript
p5.js实现动态图形临摹
2019/10/23 Javascript
[47:43]完美世界DOTA2联赛PWL S3 Magama vs GXR 第二场 12.19
2020/12/24 DOTA
python数据结构之列表和元组的详解
2017/09/23 Python
python 弹窗提示警告框MessageBox的实例
2019/06/18 Python
Python3爬虫关于识别点触点选验证码的实例讲解
2020/07/30 Python
浅谈Python3中print函数的换行
2020/08/05 Python
html5 canvas实现给图片添加平铺水印
2019/08/20 HTML / CSS
Ootori在线按摩椅店:一家专业的按摩椅制造商
2019/04/10 全球购物
大学生自我评价怎样写好
2013/10/23 职场文书
资金主管岗位职责范本
2014/03/04 职场文书
个人整改方案范文
2014/10/25 职场文书
党支部创先争优公开承诺书
2015/04/30 职场文书
党员电教片《信仰》心得体会
2016/01/15 职场文书
自荐信范文
2019/05/20 职场文书