python代码实现逻辑回归logistic原理


Posted in Python onAugust 07, 2019

Logistic Regression Classifier逻辑回归主要思想就是用最大似然概率方法构建出方程,为最大化方程,利用牛顿梯度上升求解方程参数。

  • 优点:计算代价不高,易于理解和实现。
  • 缺点:容易欠拟合,分类精度可能不高。
  • 使用数据类型:数值型和标称型数据。

介绍逻辑回归之前,我们先看一问题,有个黑箱,里面有白球和黑球,如何判断它们的比例。

我们从里面抓3个球,2个黑球,1个白球。这时候,有人就直接得出了黑球67%,白球占比33%。这个时候,其实这个人使用了最大似然概率的思想,通俗来讲,当黑球是67%的占比的时候,我们抓3个球,出现2黑1白的概率最大。我们直接用公式来说明。

假设黑球占比为P,白球为1-P。于是我们要求解MAX(PP(1-P)),显而易见P=67%(求解方法:对方程求导,使导数为0的P值即为最优解)

我们看逻辑回归,解决的是二分类问题,是不是和上面黑球白球问题很像,是的,逻辑回归也是最大似然概率来求解。

假设我们有n个独立的训练样本{(x1, y1) ,(x2, y2),…, (xn, yn)},y={0, 1}。那每一个观察到的样本(xi, yi)出现的概率是:

python代码实现逻辑回归logistic原理

上面为什么是这样呢?当y=1的时候,后面那一项是不是没有了,那就只剩下x属于1类的概率,当y=0的时候,第一项是不是没有了,那就只剩下后面那个x属于0的概率(1减去x属于1的概率)。所以不管y是0还是1,上面得到的数,都是(x, y)出现的概率。那我们的整个样本集,也就是n个独立的样本出现的似然函数为(因为每个样本都是独立的,所以n个样本出现的概率就是他们各自出现的概率相乘):

python代码实现逻辑回归logistic原理

这里我们稍微变换下L(θ):取自然对数,然后化简(不要看到一堆公式就害怕哦,很简单的哦,只需要耐心一点点,自己动手推推就知道了。注:有xi的时候,表示它是第i个样本,下面没有做区分了,相信你的眼睛是雪亮的),得到:

python代码实现逻辑回归logistic原理

其中第三步到第四步使用了下面替换。

python代码实现逻辑回归logistic原理

这时候为求最大值,对L(θ)对θ求导,得到:

python代码实现逻辑回归logistic原理

然后我们令该导数为0,即可求出最优解。但是这个方程是无法解析求解(这里就不证明了)。
最后问题变成了,求解参数使方程L最大化,求解参数的方法梯度上升法(原理这里不解释了,看详细的代码的计算方式应该更容易理解些)。

根据这个转换公式

python代码实现逻辑回归logistic原理

我们代入参数和特征,求P,也就是发生1的概率。

python代码实现逻辑回归logistic原理

上面这个也就是常提及的sigmoid函数,俗称激活函数,最后用于分类(若P(y=1|x;Θ\ThetaΘ )大于0.5,则判定为1)。

下面是详细的逻辑回归代码,代码比较简单,主要是要理解上面的算法思想。个人建议,可以结合代码看一步一步怎么算的,然后对比上面推导公式,可以让人更加容易理解,并加深印象。

from numpy import *
filename='...\\testSet.txt' #文件目录
def loadDataSet():  #读取数据(这里只有两个特征)
  dataMat = []
  labelMat = []
  fr = open(filename)
  for line in fr.readlines():
    lineArr = line.strip().split()
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  #前面的1,表示方程的常量。比如两个特征X1,X2,共需要三个参数,W1+W2*X1+W3*X2
    labelMat.append(int(lineArr[2]))
  return dataMat,labelMat

def sigmoid(inX): #sigmoid函数
  return 1.0/(1+exp(-inX))

def gradAscent(dataMat, labelMat): #梯度上升求最优参数
  dataMatrix=mat(dataMat) #将读取的数据转换为矩阵
  classLabels=mat(labelMat).transpose() #将读取的数据转换为矩阵
  m,n = shape(dataMatrix)
  alpha = 0.001 #设置梯度的阀值,该值越大梯度上升幅度越大
  maxCycles = 500 #设置迭代的次数,一般看实际数据进行设定,有些可能200次就够了
  weights = ones((n,1)) #设置初始的参数,并都赋默认值为1。注意这里权重以矩阵形式表示三个参数。
  for k in range(maxCycles):
    h = sigmoid(dataMatrix*weights)
    error = (classLabels - h)   #求导后差值
    weights = weights + alpha * dataMatrix.transpose()* error #迭代更新权重
  return weights

def stocGradAscent0(dataMat, labelMat): #随机梯度上升,当数据量比较大时,每次迭代都选择全量数据进行计算,计算量会非常大。所以采用每次迭代中一次只选择其中的一行数据进行更新权重。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  alpha=0.01
  maxCycles = 500
  weights=ones((n,1))
  for k in range(maxCycles):
    for i in range(m): #遍历计算每一行
      h = sigmoid(sum(dataMatrix[i] * weights))
      error = classLabels[i] - h
      weights = weights + alpha * error * dataMatrix[i].transpose()
  return weights

def stocGradAscent1(dataMat, labelMat): #改进版随机梯度上升,在每次迭代中随机选择样本来更新权重,并且随迭代次数增加,权重变化越小。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  weights=ones((n,1))
  maxCycles=500
  for j in range(maxCycles): #迭代
    dataIndex=[i for i in range(m)]
    for i in range(m): #随机遍历每一行
      alpha=4/(1+j+i)+0.0001 #随迭代次数增加,权重变化越小。
      randIndex=int(random.uniform(0,len(dataIndex))) #随机抽样
      h=sigmoid(sum(dataMatrix[randIndex]*weights))
      error=classLabels[randIndex]-h
      weights=weights+alpha*error*dataMatrix[randIndex].transpose()
      del(dataIndex[randIndex]) #去除已经抽取的样本
  return weights

def plotBestFit(weights): #画出最终分类的图
  import matplotlib.pyplot as plt
  dataMat,labelMat=loadDataSet()
  dataArr = array(dataMat)
  n = shape(dataArr)[0]
  xcord1 = []; ycord1 = []
  xcord2 = []; ycord2 = []
  for i in range(n):
    if int(labelMat[i])== 1:
      xcord1.append(dataArr[i,1])
      ycord1.append(dataArr[i,2])
    else:
      xcord2.append(dataArr[i,1])
      ycord2.append(dataArr[i,2])
  fig = plt.figure()
  ax = fig.add_subplot(111)
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
  ax.scatter(xcord2, ycord2, s=30, c='green')
  x = arange(-3.0, 3.0, 0.1)
  y = (-weights[0]-weights[1]*x)/weights[2]
  ax.plot(x, y)
  plt.xlabel('X1')
  plt.ylabel('X2')
  plt.show()

def main():
  dataMat, labelMat = loadDataSet()
  weights=gradAscent(dataMat, labelMat).getA()
  plotBestFit(weights)

if __name__=='__main__':
  main()

跑完代码结果:

python代码实现逻辑回归logistic原理

当然,还可以换随机梯度上升和改进的随机梯度上升算法试试,效果都还不错。

下面是代码使用的数据,可以直接复制本地text里面,跑上面代码。

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现竖排打印传单手机号码易撕条
Mar 16 Python
python+tkinter编写电脑桌面放大镜程序实例代码
Jan 16 Python
python中的set实现不重复的排序原理
Jan 24 Python
Python字典的基本用法实例分析【创建、增加、获取、修改、删除】
Mar 05 Python
Python参数类型以及常见的坑详解
Jul 08 Python
浅谈python3中input输入的使用
Aug 02 Python
Pytorch Tensor的索引与切片例子
Aug 18 Python
通过字符串导入 Python 模块的方法详解
Oct 27 Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 Python
详解python对象之间的交互
Sep 29 Python
Python基于unittest实现测试用例执行
Nov 25 Python
Python在cmd上打印彩色文字实现过程详解
Aug 07 #Python
Python如何调用外部系统命令
Aug 07 #Python
PyQt5通信机制 信号与槽详解
Aug 07 #Python
python 使用socket传输图片视频等文件的实现方式
Aug 07 #Python
python获取Pandas列名的几种方法
Aug 07 #Python
python 提取文件指定列的方法示例
Aug 07 #Python
PyQt Qt Designer工具的布局管理详解
Aug 07 #Python
You might like
PHP运行SVN命令显示某用户的文件更新记录的代码
2014/01/03 PHP
php示例详解Constructor Prototype Pattern 原型模式
2015/10/15 PHP
实例讲解php将字符串输出到HTML
2019/01/27 PHP
javascript vvorld 在线加密破解方法
2008/11/13 Javascript
判断一个变量是数组Array类型的方法
2013/09/16 Javascript
onmouseover事件和onmouseout事件全面理解
2016/08/15 Javascript
json定义及jquery操作json的方法
2016/10/03 Javascript
Node.js包管理器Yarn的入门介绍与安装
2016/10/17 Javascript
详解nodejs模板引擎制作
2017/06/14 NodeJs
AngularJS使用ng-repeat遍历二维数组元素的方法详解
2017/11/11 Javascript
js实现鼠标点击飘爱心效果
2020/08/19 Javascript
如何检测JavaScript中的死循环示例详解
2020/08/30 Javascript
Vue使用Ref跨层级获取组件的步骤
2021/01/25 Vue.js
[59:15]EG vs LGD 2018国际邀请赛淘汰赛BO3 第一场 8.26
2018/08/29 DOTA
Python标准库之多进程(multiprocessing包)介绍
2014/11/25 Python
python正则表达式中的括号匹配问题
2014/12/14 Python
Python创建二维数组实例(关于list的一个小坑)
2017/11/07 Python
Python subprocess模块常见用法分析
2018/06/12 Python
Django Admin中增加导出CSV功能过程解析
2019/09/04 Python
python调用Matplotlib绘制分布点图
2019/10/18 Python
python打印直角三角形与等腰三角形实例代码
2019/10/20 Python
python scrapy重复执行实现代码详解
2019/12/28 Python
python 项目目录结构设置
2020/02/14 Python
python不相等的两个字符串的 if 条件判断为True详解
2020/03/12 Python
python代码区分大小写吗
2020/06/17 Python
python实现将中文日期转换为数字日期
2020/07/14 Python
重新定义牛仔布,100美元以下:Warp + Weft
2018/07/25 全球购物
室内设计专业个人的自我评价
2013/10/19 职场文书
文明寄语大全
2014/04/11 职场文书
2014公安机关纪律作风整顿思想汇报
2014/09/13 职场文书
2014年督导工作总结
2014/11/19 职场文书
涪陵白鹤梁导游词
2015/02/09 职场文书
电台广播稿范文
2015/08/19 职场文书
七年级上册生物的课件
2019/08/07 职场文书
PhpSpreadsheet中文文档 | Spreadsheet操作教程实例
2021/04/01 PHP
mybatis 获取更新记录的id
2022/05/20 Java/Android