python代码实现逻辑回归logistic原理


Posted in Python onAugust 07, 2019

Logistic Regression Classifier逻辑回归主要思想就是用最大似然概率方法构建出方程,为最大化方程,利用牛顿梯度上升求解方程参数。

  • 优点:计算代价不高,易于理解和实现。
  • 缺点:容易欠拟合,分类精度可能不高。
  • 使用数据类型:数值型和标称型数据。

介绍逻辑回归之前,我们先看一问题,有个黑箱,里面有白球和黑球,如何判断它们的比例。

我们从里面抓3个球,2个黑球,1个白球。这时候,有人就直接得出了黑球67%,白球占比33%。这个时候,其实这个人使用了最大似然概率的思想,通俗来讲,当黑球是67%的占比的时候,我们抓3个球,出现2黑1白的概率最大。我们直接用公式来说明。

假设黑球占比为P,白球为1-P。于是我们要求解MAX(PP(1-P)),显而易见P=67%(求解方法:对方程求导,使导数为0的P值即为最优解)

我们看逻辑回归,解决的是二分类问题,是不是和上面黑球白球问题很像,是的,逻辑回归也是最大似然概率来求解。

假设我们有n个独立的训练样本{(x1, y1) ,(x2, y2),…, (xn, yn)},y={0, 1}。那每一个观察到的样本(xi, yi)出现的概率是:

python代码实现逻辑回归logistic原理

上面为什么是这样呢?当y=1的时候,后面那一项是不是没有了,那就只剩下x属于1类的概率,当y=0的时候,第一项是不是没有了,那就只剩下后面那个x属于0的概率(1减去x属于1的概率)。所以不管y是0还是1,上面得到的数,都是(x, y)出现的概率。那我们的整个样本集,也就是n个独立的样本出现的似然函数为(因为每个样本都是独立的,所以n个样本出现的概率就是他们各自出现的概率相乘):

python代码实现逻辑回归logistic原理

这里我们稍微变换下L(θ):取自然对数,然后化简(不要看到一堆公式就害怕哦,很简单的哦,只需要耐心一点点,自己动手推推就知道了。注:有xi的时候,表示它是第i个样本,下面没有做区分了,相信你的眼睛是雪亮的),得到:

python代码实现逻辑回归logistic原理

其中第三步到第四步使用了下面替换。

python代码实现逻辑回归logistic原理

这时候为求最大值,对L(θ)对θ求导,得到:

python代码实现逻辑回归logistic原理

然后我们令该导数为0,即可求出最优解。但是这个方程是无法解析求解(这里就不证明了)。
最后问题变成了,求解参数使方程L最大化,求解参数的方法梯度上升法(原理这里不解释了,看详细的代码的计算方式应该更容易理解些)。

根据这个转换公式

python代码实现逻辑回归logistic原理

我们代入参数和特征,求P,也就是发生1的概率。

python代码实现逻辑回归logistic原理

上面这个也就是常提及的sigmoid函数,俗称激活函数,最后用于分类(若P(y=1|x;Θ\ThetaΘ )大于0.5,则判定为1)。

下面是详细的逻辑回归代码,代码比较简单,主要是要理解上面的算法思想。个人建议,可以结合代码看一步一步怎么算的,然后对比上面推导公式,可以让人更加容易理解,并加深印象。

from numpy import *
filename='...\\testSet.txt' #文件目录
def loadDataSet():  #读取数据(这里只有两个特征)
  dataMat = []
  labelMat = []
  fr = open(filename)
  for line in fr.readlines():
    lineArr = line.strip().split()
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  #前面的1,表示方程的常量。比如两个特征X1,X2,共需要三个参数,W1+W2*X1+W3*X2
    labelMat.append(int(lineArr[2]))
  return dataMat,labelMat

def sigmoid(inX): #sigmoid函数
  return 1.0/(1+exp(-inX))

def gradAscent(dataMat, labelMat): #梯度上升求最优参数
  dataMatrix=mat(dataMat) #将读取的数据转换为矩阵
  classLabels=mat(labelMat).transpose() #将读取的数据转换为矩阵
  m,n = shape(dataMatrix)
  alpha = 0.001 #设置梯度的阀值,该值越大梯度上升幅度越大
  maxCycles = 500 #设置迭代的次数,一般看实际数据进行设定,有些可能200次就够了
  weights = ones((n,1)) #设置初始的参数,并都赋默认值为1。注意这里权重以矩阵形式表示三个参数。
  for k in range(maxCycles):
    h = sigmoid(dataMatrix*weights)
    error = (classLabels - h)   #求导后差值
    weights = weights + alpha * dataMatrix.transpose()* error #迭代更新权重
  return weights

def stocGradAscent0(dataMat, labelMat): #随机梯度上升,当数据量比较大时,每次迭代都选择全量数据进行计算,计算量会非常大。所以采用每次迭代中一次只选择其中的一行数据进行更新权重。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  alpha=0.01
  maxCycles = 500
  weights=ones((n,1))
  for k in range(maxCycles):
    for i in range(m): #遍历计算每一行
      h = sigmoid(sum(dataMatrix[i] * weights))
      error = classLabels[i] - h
      weights = weights + alpha * error * dataMatrix[i].transpose()
  return weights

def stocGradAscent1(dataMat, labelMat): #改进版随机梯度上升,在每次迭代中随机选择样本来更新权重,并且随迭代次数增加,权重变化越小。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  weights=ones((n,1))
  maxCycles=500
  for j in range(maxCycles): #迭代
    dataIndex=[i for i in range(m)]
    for i in range(m): #随机遍历每一行
      alpha=4/(1+j+i)+0.0001 #随迭代次数增加,权重变化越小。
      randIndex=int(random.uniform(0,len(dataIndex))) #随机抽样
      h=sigmoid(sum(dataMatrix[randIndex]*weights))
      error=classLabels[randIndex]-h
      weights=weights+alpha*error*dataMatrix[randIndex].transpose()
      del(dataIndex[randIndex]) #去除已经抽取的样本
  return weights

def plotBestFit(weights): #画出最终分类的图
  import matplotlib.pyplot as plt
  dataMat,labelMat=loadDataSet()
  dataArr = array(dataMat)
  n = shape(dataArr)[0]
  xcord1 = []; ycord1 = []
  xcord2 = []; ycord2 = []
  for i in range(n):
    if int(labelMat[i])== 1:
      xcord1.append(dataArr[i,1])
      ycord1.append(dataArr[i,2])
    else:
      xcord2.append(dataArr[i,1])
      ycord2.append(dataArr[i,2])
  fig = plt.figure()
  ax = fig.add_subplot(111)
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
  ax.scatter(xcord2, ycord2, s=30, c='green')
  x = arange(-3.0, 3.0, 0.1)
  y = (-weights[0]-weights[1]*x)/weights[2]
  ax.plot(x, y)
  plt.xlabel('X1')
  plt.ylabel('X2')
  plt.show()

def main():
  dataMat, labelMat = loadDataSet()
  weights=gradAscent(dataMat, labelMat).getA()
  plotBestFit(weights)

if __name__=='__main__':
  main()

跑完代码结果:

python代码实现逻辑回归logistic原理

当然,还可以换随机梯度上升和改进的随机梯度上升算法试试,效果都还不错。

下面是代码使用的数据,可以直接复制本地text里面,跑上面代码。

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之ReportLab绘制条形码和二维码的实例
Jan 15 Python
python selenium 获取标签的属性值、内容、状态方法
Jun 22 Python
python 常见字符串与函数的用法详解
Nov 23 Python
在Django中URL正则表达式匹配的方法
Dec 20 Python
python构建基础的爬虫教学
Dec 23 Python
Django Rest framework频率原理与限制
Jul 26 Python
利用python在大量数据文件下删除某一行的例子
Aug 21 Python
python线程的几种创建方式详解
Aug 29 Python
python安装scipy的步骤解析
Sep 28 Python
python 图片二值化处理(处理后为纯黑白的图片)
Nov 01 Python
Python selenium的基本使用方法分析
Dec 21 Python
Python Pygame实战在打砖块游戏的实现
Mar 17 Python
Python在cmd上打印彩色文字实现过程详解
Aug 07 #Python
Python如何调用外部系统命令
Aug 07 #Python
PyQt5通信机制 信号与槽详解
Aug 07 #Python
python 使用socket传输图片视频等文件的实现方式
Aug 07 #Python
python获取Pandas列名的几种方法
Aug 07 #Python
python 提取文件指定列的方法示例
Aug 07 #Python
PyQt Qt Designer工具的布局管理详解
Aug 07 #Python
You might like
使用 MySQL 开始 PHP 会话
2006/12/21 PHP
PHP如何实现跨域
2016/05/30 PHP
javascript基于jQuery的表格悬停变色/恢复,表格点击变色/恢复,点击行选Checkbox
2008/08/05 Javascript
JavaScript flash复制库类 Zero Clipboard
2011/01/17 Javascript
eclipse如何忽略js文件报错(附图)
2013/10/30 Javascript
javascript设置金额样式转换保留两位小数示例代码
2013/12/04 Javascript
Nodejs学习笔记之Stream模块
2015/01/13 NodeJs
JS基于面向对象实现的放烟花效果
2015/05/07 Javascript
Bootstrap入门教程一Hello Bootstrap初识
2017/03/02 Javascript
Vue.js划分组件的方法
2017/10/29 Javascript
Javascript中prototype与__proto__的关系详解
2018/03/11 Javascript
AngularJs分页插件使用详解
2018/06/30 Javascript
JS canvas绘制五子棋的棋盘
2020/05/28 Javascript
vue-cli脚手架搭建的项目去除eslint验证的方法
2018/09/29 Javascript
vue实现条件叠加搜索的解决方法
2019/05/28 Javascript
vuex + keep-alive实现tab标签页面缓存功能
2019/10/17 Javascript
vue实现在线预览pdf文件和下载(pdf.js)
2019/11/26 Javascript
Vue+Element ui 根据后台返回数据设置动态表头操作
2020/09/21 Javascript
Vue 简单实现前端权限控制的示例
2020/12/25 Vue.js
[05:59]2018DOTA2国际邀请赛寻真——只为胜利的Secret
2018/08/13 DOTA
浅析Python中MySQLdb的事务处理功能
2016/09/21 Python
在 Python 应用中使用 MongoDB的方法
2017/01/05 Python
Python3进制之间的转换代码实例
2019/08/24 Python
wxpython实现按钮切换界面的方法
2019/11/19 Python
利用python实现AR教程
2019/11/20 Python
python利用蒙版抠图(使用PIL.Image和cv2)输出透明背景图
2020/08/04 Python
Python经纬度坐标转换为距离及角度的实现
2020/11/01 Python
Python创建自己的加密货币的示例
2021/03/01 Python
一个SQL面试题
2014/08/21 面试题
Java面试题:为什么要用Java
2012/05/11 面试题
大学生职业生涯规划范文
2013/12/31 职场文书
个人自我评价范文
2014/02/05 职场文书
前处理组长岗位职责
2014/03/01 职场文书
夫妻分居协议书范文
2014/11/26 职场文书
2014工程部年度工作总结
2014/12/17 职场文书
蔬果开业典礼发言稿应该怎么写?
2019/09/03 职场文书