python代码实现逻辑回归logistic原理


Posted in Python onAugust 07, 2019

Logistic Regression Classifier逻辑回归主要思想就是用最大似然概率方法构建出方程,为最大化方程,利用牛顿梯度上升求解方程参数。

  • 优点:计算代价不高,易于理解和实现。
  • 缺点:容易欠拟合,分类精度可能不高。
  • 使用数据类型:数值型和标称型数据。

介绍逻辑回归之前,我们先看一问题,有个黑箱,里面有白球和黑球,如何判断它们的比例。

我们从里面抓3个球,2个黑球,1个白球。这时候,有人就直接得出了黑球67%,白球占比33%。这个时候,其实这个人使用了最大似然概率的思想,通俗来讲,当黑球是67%的占比的时候,我们抓3个球,出现2黑1白的概率最大。我们直接用公式来说明。

假设黑球占比为P,白球为1-P。于是我们要求解MAX(PP(1-P)),显而易见P=67%(求解方法:对方程求导,使导数为0的P值即为最优解)

我们看逻辑回归,解决的是二分类问题,是不是和上面黑球白球问题很像,是的,逻辑回归也是最大似然概率来求解。

假设我们有n个独立的训练样本{(x1, y1) ,(x2, y2),…, (xn, yn)},y={0, 1}。那每一个观察到的样本(xi, yi)出现的概率是:

python代码实现逻辑回归logistic原理

上面为什么是这样呢?当y=1的时候,后面那一项是不是没有了,那就只剩下x属于1类的概率,当y=0的时候,第一项是不是没有了,那就只剩下后面那个x属于0的概率(1减去x属于1的概率)。所以不管y是0还是1,上面得到的数,都是(x, y)出现的概率。那我们的整个样本集,也就是n个独立的样本出现的似然函数为(因为每个样本都是独立的,所以n个样本出现的概率就是他们各自出现的概率相乘):

python代码实现逻辑回归logistic原理

这里我们稍微变换下L(θ):取自然对数,然后化简(不要看到一堆公式就害怕哦,很简单的哦,只需要耐心一点点,自己动手推推就知道了。注:有xi的时候,表示它是第i个样本,下面没有做区分了,相信你的眼睛是雪亮的),得到:

python代码实现逻辑回归logistic原理

其中第三步到第四步使用了下面替换。

python代码实现逻辑回归logistic原理

这时候为求最大值,对L(θ)对θ求导,得到:

python代码实现逻辑回归logistic原理

然后我们令该导数为0,即可求出最优解。但是这个方程是无法解析求解(这里就不证明了)。
最后问题变成了,求解参数使方程L最大化,求解参数的方法梯度上升法(原理这里不解释了,看详细的代码的计算方式应该更容易理解些)。

根据这个转换公式

python代码实现逻辑回归logistic原理

我们代入参数和特征,求P,也就是发生1的概率。

python代码实现逻辑回归logistic原理

上面这个也就是常提及的sigmoid函数,俗称激活函数,最后用于分类(若P(y=1|x;Θ\ThetaΘ )大于0.5,则判定为1)。

下面是详细的逻辑回归代码,代码比较简单,主要是要理解上面的算法思想。个人建议,可以结合代码看一步一步怎么算的,然后对比上面推导公式,可以让人更加容易理解,并加深印象。

from numpy import *
filename='...\\testSet.txt' #文件目录
def loadDataSet():  #读取数据(这里只有两个特征)
  dataMat = []
  labelMat = []
  fr = open(filename)
  for line in fr.readlines():
    lineArr = line.strip().split()
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  #前面的1,表示方程的常量。比如两个特征X1,X2,共需要三个参数,W1+W2*X1+W3*X2
    labelMat.append(int(lineArr[2]))
  return dataMat,labelMat

def sigmoid(inX): #sigmoid函数
  return 1.0/(1+exp(-inX))

def gradAscent(dataMat, labelMat): #梯度上升求最优参数
  dataMatrix=mat(dataMat) #将读取的数据转换为矩阵
  classLabels=mat(labelMat).transpose() #将读取的数据转换为矩阵
  m,n = shape(dataMatrix)
  alpha = 0.001 #设置梯度的阀值,该值越大梯度上升幅度越大
  maxCycles = 500 #设置迭代的次数,一般看实际数据进行设定,有些可能200次就够了
  weights = ones((n,1)) #设置初始的参数,并都赋默认值为1。注意这里权重以矩阵形式表示三个参数。
  for k in range(maxCycles):
    h = sigmoid(dataMatrix*weights)
    error = (classLabels - h)   #求导后差值
    weights = weights + alpha * dataMatrix.transpose()* error #迭代更新权重
  return weights

def stocGradAscent0(dataMat, labelMat): #随机梯度上升,当数据量比较大时,每次迭代都选择全量数据进行计算,计算量会非常大。所以采用每次迭代中一次只选择其中的一行数据进行更新权重。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  alpha=0.01
  maxCycles = 500
  weights=ones((n,1))
  for k in range(maxCycles):
    for i in range(m): #遍历计算每一行
      h = sigmoid(sum(dataMatrix[i] * weights))
      error = classLabels[i] - h
      weights = weights + alpha * error * dataMatrix[i].transpose()
  return weights

def stocGradAscent1(dataMat, labelMat): #改进版随机梯度上升,在每次迭代中随机选择样本来更新权重,并且随迭代次数增加,权重变化越小。
  dataMatrix=mat(dataMat)
  classLabels=labelMat
  m,n=shape(dataMatrix)
  weights=ones((n,1))
  maxCycles=500
  for j in range(maxCycles): #迭代
    dataIndex=[i for i in range(m)]
    for i in range(m): #随机遍历每一行
      alpha=4/(1+j+i)+0.0001 #随迭代次数增加,权重变化越小。
      randIndex=int(random.uniform(0,len(dataIndex))) #随机抽样
      h=sigmoid(sum(dataMatrix[randIndex]*weights))
      error=classLabels[randIndex]-h
      weights=weights+alpha*error*dataMatrix[randIndex].transpose()
      del(dataIndex[randIndex]) #去除已经抽取的样本
  return weights

def plotBestFit(weights): #画出最终分类的图
  import matplotlib.pyplot as plt
  dataMat,labelMat=loadDataSet()
  dataArr = array(dataMat)
  n = shape(dataArr)[0]
  xcord1 = []; ycord1 = []
  xcord2 = []; ycord2 = []
  for i in range(n):
    if int(labelMat[i])== 1:
      xcord1.append(dataArr[i,1])
      ycord1.append(dataArr[i,2])
    else:
      xcord2.append(dataArr[i,1])
      ycord2.append(dataArr[i,2])
  fig = plt.figure()
  ax = fig.add_subplot(111)
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
  ax.scatter(xcord2, ycord2, s=30, c='green')
  x = arange(-3.0, 3.0, 0.1)
  y = (-weights[0]-weights[1]*x)/weights[2]
  ax.plot(x, y)
  plt.xlabel('X1')
  plt.ylabel('X2')
  plt.show()

def main():
  dataMat, labelMat = loadDataSet()
  weights=gradAscent(dataMat, labelMat).getA()
  plotBestFit(weights)

if __name__=='__main__':
  main()

跑完代码结果:

python代码实现逻辑回归logistic原理

当然,还可以换随机梯度上升和改进的随机梯度上升算法试试,效果都还不错。

下面是代码使用的数据,可以直接复制本地text里面,跑上面代码。

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用rsa加密算法模块模拟新浪微博登录
Jan 22 Python
Python 制作糗事百科爬虫实例
Sep 22 Python
Django权限机制实现代码详解
Feb 05 Python
python3下使用cv2.imwrite存储带有中文路径图片的方法
May 10 Python
win7 x64系统中安装Scrapy的方法
Nov 18 Python
Python + OpenCV 实现LBP特征提取的示例代码
Jul 11 Python
flask实现验证码并验证功能
Dec 05 Python
Python中if有多个条件处理方法
Feb 26 Python
如何理解Python中包的引入
May 29 Python
python中如何写类
Jun 29 Python
pyqt5实现井字棋的示例代码
Dec 07 Python
python实现excel公式格式化的示例代码
Dec 23 Python
Python在cmd上打印彩色文字实现过程详解
Aug 07 #Python
Python如何调用外部系统命令
Aug 07 #Python
PyQt5通信机制 信号与槽详解
Aug 07 #Python
python 使用socket传输图片视频等文件的实现方式
Aug 07 #Python
python获取Pandas列名的几种方法
Aug 07 #Python
python 提取文件指定列的方法示例
Aug 07 #Python
PyQt Qt Designer工具的布局管理详解
Aug 07 #Python
You might like
解析PHP中的file_get_contents获取远程页面乱码的问题
2013/06/25 PHP
php读取csv文件后,uft8 bom导致在页面上显示出现问题的解决方法
2013/08/10 PHP
PHP实现图片上传并压缩
2015/12/22 PHP
PJBlog插件 防刷新的在线播放器
2006/10/25 Javascript
javascript 数组排序函数
2009/08/20 Javascript
js 格式化时间日期函数小结
2010/03/20 Javascript
jquery ajax post提交数据乱码
2013/11/05 Javascript
js+jquery常用知识点汇总
2015/03/03 Javascript
JavaScript使用Prototype实现面向对象的方法
2015/04/14 Javascript
js模拟淘宝网的多级选择菜单实现方法
2015/08/18 Javascript
基于javascript bootstrap实现生日日期联动选择
2016/04/07 Javascript
JavaScript实现相册弹窗功能(zepto.js)
2016/06/21 Javascript
Vuejs第十二篇之动态组件全面解析
2016/09/09 Javascript
jQuery实现 上升、下降、删除、添加一行代码
2017/03/06 Javascript
Ionic + Angular.js实现验证码倒计时功能的方法
2017/06/12 Javascript
node.js自动上传ftp的脚本分享
2018/06/16 Javascript
vue项目base64字符串转图片的实现代码
2018/07/13 Javascript
vue项目中使用vue-i18n报错的解决方法
2019/01/13 Javascript
node.js的http.createServer过程深入解析
2019/06/06 Javascript
解决vue更新路由router-view复用组件内容不刷新的问题
2019/11/04 Javascript
javascript 数组精简技巧小结
2020/02/26 Javascript
bootstrap-table后端分页功能完整实例
2020/06/01 Javascript
在vue中实现清除echarts上次保留的数据(亲测有效)
2020/09/09 Javascript
[00:18]天涯墨客三技能展示
2018/08/25 DOTA
对python产生随机的二维数组实例详解
2018/12/13 Python
Python设置matplotlib.plot的坐标轴刻度间隔以及刻度范围
2019/06/25 Python
python实现两个dict合并与计算操作示例
2019/07/01 Python
Pytorch在NLP中的简单应用详解
2020/01/08 Python
Python脚本实现监听服务器的思路代码详解
2020/05/28 Python
python 负数取模运算实例
2020/06/03 Python
Python高并发和多线程有什么关系
2020/11/14 Python
详解WebSocket跨域问题解决
2018/08/06 HTML / CSS
预订全球最佳旅行体验:Viator
2018/03/30 全球购物
初一家长会邀请函
2014/01/31 职场文书
计算机科学技术自荐信
2014/06/12 职场文书
2015年国际护士节演讲稿
2015/03/18 职场文书