机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
python创建进程fork用法
Jun 04 Python
使用简单工厂模式来进行Python的设计模式编程
Mar 01 Python
Python网络爬虫实例讲解
Apr 28 Python
python整小时 整天时间戳获取算法示例
Feb 20 Python
python在不同条件下的输入与输出
Feb 13 Python
Python3 assert断言实现原理解析
Mar 02 Python
使用python实现飞机大战游戏
Mar 23 Python
Selenium基于PIL实现拼接滚动截图
Apr 10 Python
python实现俄罗斯方块小游戏
Apr 24 Python
Python如何读取、写入JSON数据
Jul 28 Python
Python自动化测试中yaml文件读取操作
Aug 20 Python
如何利用python正则表达式匹配版本信息
Dec 09 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
PHP开发文件系统实例讲解
2006/10/09 PHP
php地址引用(php地址引用的效率问题)
2012/03/23 PHP
使用PHP编写的SVN类
2013/07/18 PHP
php+js实现图片的上传、裁剪、预览、提交示例
2013/08/27 PHP
PHP中rename()函数的妙用讲解
2019/02/28 PHP
PHP 代码简洁之道(小结)
2019/10/16 PHP
PHP实现简单的协程任务调度demo示例
2020/02/01 PHP
jQuery取id有.的值的方法
2014/05/21 Javascript
jquery实现的鼠标下拉滚动置顶效果
2014/07/24 Javascript
avascript中的自执行匿名函数应用示例
2014/09/15 Javascript
Thinkphp模板没有解析直接原样输出的解决方法
2014/10/31 Javascript
jQuery $.each遍历对象、数组用法实例
2015/04/16 Javascript
javascript实现的简单计时器
2015/07/19 Javascript
使用jquery获取url以及jquery获取url参数的实现方法
2016/05/25 Javascript
一个仿微博登陆邮箱提示框js开发案例
2016/07/28 Javascript
JavaScript比较当前时间是否在指定时间段内的方法
2016/08/02 Javascript
微信小程序实现下拉刷新和轮播图效果
2017/11/21 Javascript
vue.js添加一些触摸事件以及安装fastclick的实例
2018/08/28 Javascript
vue.js+element 默认提示中英文操作
2020/11/11 Javascript
可用于监控 mysql Master Slave 状态的python代码
2013/02/10 Python
详解使用python的logging模块在stdout输出的两种方法
2017/05/17 Python
浅谈python迭代器
2017/11/08 Python
python实现数据导出到excel的示例--普通格式
2018/05/03 Python
python之mock模块基本使用方法详解
2019/06/27 Python
如何打包Python Web项目实现免安装一键启动的方法
2020/05/21 Python
波兰汽车配件网上商店:iParts.pl
2020/09/08 全球购物
远程调用的原理
2014/07/05 面试题
园林资料员岗位职责
2013/12/30 职场文书
《油菜花开了》教学反思
2014/02/22 职场文书
中学生寄语大全
2014/04/03 职场文书
就业协议书样本
2014/08/20 职场文书
2014年教育实习工作总结
2014/11/22 职场文书
审美与表现自我评价
2015/03/09 职场文书
入党积极分子半年考察意见
2015/06/02 职场文书
2016年优秀共青团员事迹材料
2016/02/25 职场文书
创业计划之特色精品店
2019/08/12 职场文书