机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
sqlalchemy对象转dict的示例
Apr 22 Python
python中二维阵列的变换实例
Oct 09 Python
Python实现一个简单的MySQL类
Jan 07 Python
最近Python有点火? 给你7个学习它的理由!
Jun 26 Python
Python实现PS图像调整黑白效果示例
Jan 25 Python
解决python opencv无法显示图片的问题
Oct 28 Python
python 搜索大文件的实例代码
Jul 08 Python
Python读写压缩文件的方法
Jul 30 Python
详解numpy1.19.4与python3.9版本冲突解决
Dec 15 Python
python中K-means算法基础知识点
Jan 25 Python
python中pandas.read_csv()函数的深入讲解
Mar 29 Python
python数字转对应中文的方法总结
Aug 02 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
解析curl提交GET,POST,Cookie的简单方法
2013/06/29 PHP
destoon文章模块调用企业会员资料的方法
2014/08/22 PHP
ThinkPHP3.2.3实现分页的方法详解
2016/06/03 PHP
php集成动态口令认证
2016/07/21 PHP
PHP图片水印类的封装
2017/07/06 PHP
PHP常用的类封装小结【4个工具类】
2019/06/28 PHP
jQuery 选择器理解
2010/03/16 Javascript
JS控制一个DIV层在指定时间内消失的方法
2014/02/17 Javascript
Javascript基础教程之数据类型 (数值 Number)
2015/01/18 Javascript
常用jQuery代码分享
2015/07/14 Javascript
jQuery无刷新上传之uploadify3.1简单使用
2016/06/18 Javascript
svg动画之动态描边效果
2017/02/22 Javascript
JavaScript 过滤关键字
2017/03/20 Javascript
jQuery轻松实现无缝轮播效果
2017/03/22 jQuery
Bootstrap标签页(Tab)插件切换echarts不显示问题的解决
2018/07/13 Javascript
React 条件渲染最佳实践小结(7种)
2020/09/27 Javascript
[34:39]Secret vs VG 2018国际邀请赛淘汰赛BO3 第二场 8.23
2018/08/24 DOTA
[01:11:21]DOTA2-DPC中国联赛 正赛 Phoenix vs CDEC BO3 第三场 3月7日
2021/03/11 DOTA
Python Tkinter基础控件用法
2014/09/03 Python
在Python中定义和使用抽象类的方法
2016/06/30 Python
python获取url的返回信息方法
2018/12/17 Python
11个Python3字典内置方法大全与示例汇总
2019/05/13 Python
python查看数据类型的方法
2019/10/12 Python
Python一行代码解决矩阵旋转的问题
2019/11/30 Python
使用PyTorch实现MNIST手写体识别代码
2020/01/18 Python
python实现在一个画布上画多个子图
2020/01/19 Python
Django基于客户端下载文件实现方法
2020/04/21 Python
python 还原梯度下降算法实现一维线性回归
2020/10/22 Python
SQL Server里面什么样的视图才能创建索引
2015/04/17 面试题
追悼会子女答谢词
2014/01/28 职场文书
第一批党的群众路线教育实践活动工作总结
2014/03/03 职场文书
五年后的职业生涯规划
2014/03/04 职场文书
销售会议开幕词
2015/01/28 职场文书
律师催款函范文
2015/06/24 职场文书
《天净沙·秋思》教学反思三篇
2019/11/02 职场文书
Go语言空白表示符_的实例用法
2021/07/04 Golang