机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
python使用7z解压软件备份文件脚本分享
Feb 21 Python
深入浅析python定时杀进程
Jun 06 Python
tensorflow训练中出现nan问题的解决
Feb 10 Python
使用pandas read_table读取csv文件的方法
Jul 04 Python
Python实现string字符串连接的方法总结【8种方式】
Jul 06 Python
python 返回一个列表中第二大的数方法
Jul 09 Python
Python  Django 母版和继承解析
Aug 09 Python
Python单例模式的四种创建方式实例解析
Mar 04 Python
python实现简单俄罗斯方块
Mar 13 Python
Django模板获取field的verbose_name实例
May 19 Python
如何一键升级Python所有包
Nov 05 Python
PyQt5中QSpinBox计数器的实现
Jan 18 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
求PHP数组最大值,最小值的代码
2011/10/31 PHP
php实现图片上传并进行替换操作
2016/03/15 PHP
smarty循环嵌套用法示例分析
2016/07/19 PHP
php单元测试phpunit入门实例教程
2017/11/17 PHP
php微信公众号开发之音乐信息
2018/10/20 PHP
php curl简单采集图片生成base64编码(并附curl函数参数说明)
2019/02/15 PHP
JavaScript调用Activex控件的事件的实现方法
2010/04/11 Javascript
$.ajax返回的JSON无法执行success的解决方法
2011/09/09 Javascript
jquery封装的对话框简单实现
2013/07/21 Javascript
javascript右下角弹层及自动隐藏(自己编写)
2013/11/20 Javascript
jquery复选框checkbox实现删除前判断
2014/04/20 Javascript
js生成缩略图后上传并利用canvas重绘
2014/05/15 Javascript
浅谈JavaScript中的String对象常用方法
2015/02/25 Javascript
javascript实现用户点击数量统计
2016/12/25 Javascript
利用JavaScript实现栈的数据结构示例代码
2017/08/02 Javascript
使用vue-cli导入Element UI组件的方法
2018/05/16 Javascript
对angularjs框架下controller间的传值方法详解
2018/10/08 Javascript
vue实现跨域的方法分析
2019/05/21 Javascript
JavaScript如何实现图片处理与合成
2020/05/29 Javascript
[01:02:18]VGJ.S vs infamous Supermajor 败者组 BO3 第一场 6.4
2018/06/05 DOTA
Python中捕捉详细异常信息的代码示例
2014/09/18 Python
Python错误: SyntaxError: Non-ASCII character解决办法
2017/06/08 Python
Python3 实现随机生成一组不重复数并按行写入文件
2018/04/09 Python
基于Python中求和函数sum的用法详解
2018/06/28 Python
通过pykafka接收Kafka消息队列的方法
2018/12/27 Python
使用Python 统计高频字数的方法
2019/01/31 Python
PyTorch中常用的激活函数的方法示例
2019/08/20 Python
Python定时任务框架APScheduler原理及常用代码
2020/10/05 Python
CSS3弹性盒模型开发笔记(三)
2016/04/26 HTML / CSS
详解HTML5 Canvas绘制不规则图形时的非零环绕原则
2016/03/21 HTML / CSS
英国领先的隐形眼镜在线供应商:Lenstore.co.uk
2019/11/24 全球购物
一个C/C++编程面试题
2013/11/10 面试题
涨价通知怎么写
2015/04/23 职场文书
公司考勤管理制度
2015/08/04 职场文书
企业年会祝酒词
2015/08/11 职场文书
公司董事任命书
2015/09/21 职场文书