机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
py2exe 编译ico图标的代码
Mar 08 Python
Python3通过Luhn算法快速验证信用卡卡号的方法
May 14 Python
python搭建虚拟环境的步骤详解
Sep 27 Python
Python生成器以及应用实例解析
Feb 08 Python
Python异常处理操作实例详解
May 10 Python
Pandas读取MySQL数据到DataFrame的方法
Jul 25 Python
python itchat给指定联系人发消息的方法
Jun 11 Python
Python2和3字符编码的区别知识点整理
Aug 08 Python
利用matplotlib实现根据实时数据动态更新图形
Dec 13 Python
利用python3 的pygame模块实现塔防游戏
Dec 30 Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
Jun 22 Python
Python self用法详解
Nov 28 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
php获得文件扩展名三法
2006/11/25 PHP
php at(@)符号的用法简介
2009/07/11 PHP
php函数array_merge用法一例(合并同类数组)
2013/02/03 PHP
测试PHP连接MYSQL成功与否的代码
2013/08/16 PHP
php文件上传后端处理小技巧
2016/05/22 PHP
[原创]php简单隔行变色功能实现代码
2016/07/09 PHP
php实现往pdf中加数字签名操作示例【附源码下载】
2018/08/07 PHP
javascript 面向对象,实现namespace,class,继承,重载
2009/10/29 Javascript
jQuery 1.4 15个你应该知道的新特性(译)
2010/01/24 Javascript
你必须知道的Javascript知识点之&quot;this指针&quot;的应用
2013/04/23 Javascript
使用百度地图api实现根据地址查询经纬度
2014/12/11 Javascript
js面向对象之静态方法和静态属性实例分析
2015/01/10 Javascript
Javascript 高阶函数使用介绍
2015/06/15 Javascript
在JavaScript中访问字符串的子串
2015/07/07 Javascript
javascript省市级联功能实现方法实例详解
2015/10/20 Javascript
在Node.js中使用Javascript Generators详解
2016/05/05 Javascript
BootStrap智能表单demo示例详解
2016/06/13 Javascript
你知道setTimeout是如何运行的吗?
2016/08/16 Javascript
js判断手机系统是android还是ios
2017/03/07 Javascript
json的结构与遍历方法实例分析
2017/04/25 Javascript
JavaScript中立即执行函数实例详解
2017/11/04 Javascript
微信小程序自定义导航教程(兼容各种手机)
2018/12/12 Javascript
如何利用vue+vue-router+elementUI实现简易通讯录
2019/05/13 Javascript
微信小程序webview 脚手架使用详解
2019/07/22 Javascript
vue基本使用--refs获取组件或元素的实例
2019/11/07 Javascript
mpvue实现微信小程序快递单号查询代码
2020/04/03 Javascript
[01:01:52]DOTA2-DPC中国联赛正赛 iG vs LBZS BO3 第一场 3月4日
2021/03/11 DOTA
python制作最美应用的爬虫
2015/10/28 Python
Python实现的堆排序算法示例
2018/04/29 Python
python3读取csv和xlsx文件的实例
2018/06/22 Python
Python如何在单元测试中给对象打补丁
2020/08/03 Python
波兰电子产品购物网站:Vobis
2019/05/26 全球购物
俄罗斯奢侈品牌衣服、鞋子和配饰的在线商店:INTERMODA
2020/07/17 全球购物
自我推荐信怎么写
2015/03/24 职场文书
2016廉洁教育心得体会
2016/01/20 职场文书
SQLServer 错误: 15404,无法获取有关 Windows NT 组/用户 WIN-8IVSNAQS8T7\Administrator 的信息
2021/06/30 SQL Server