机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
Cython 三分钟入门教程
Sep 17 Python
python将文本分每两行一组并保存到文件
Mar 19 Python
python3 判断列表是一个空列表的方法
May 04 Python
解决python3读取Python2存储的pickle文件问题
Oct 25 Python
安装好Pycharm后如何配置Python解释器简易教程
Jun 28 Python
python数据类型可变不可变知识点总结
Mar 06 Python
浅谈Pycharm的项目文件名是红色的原因及解决方式
Jun 01 Python
一文解决django 2.2与mysql兼容性问题
Jul 15 Python
基于python判断字符串括号是否闭合{}[]()
Sep 21 Python
python 解决Windows平台上路径有空格的问题
Nov 10 Python
python spilt()分隔字符串的实现示例
May 21 Python
Python极值整数的边界探讨分析
Sep 15 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
如何对PHP程序中的常见漏洞进行攻击(上)
2006/10/09 PHP
PHP积分兑换接口实例
2015/02/09 PHP
PHP易混淆知识整理笔记
2015/09/24 PHP
PHP实现JS中escape与unescape的方法
2016/07/11 PHP
使用Codeigniter重写insert的方法(推荐)
2017/03/23 PHP
接收键盘指令的脚本
2006/06/26 Javascript
得到文本框选中的文字,动态插入文字的js代码
2007/03/07 Javascript
使用jQuery fancybox插件打造一个实用的数据传输模态弹出窗体
2013/01/15 Javascript
js中 关于undefined和null的区别介绍
2013/04/16 Javascript
js禁止回车提交表单的示例代码
2013/12/23 Javascript
js中直接声明一个对象的方法
2014/08/10 Javascript
JavaScript实现网页加载进度条代码超简单
2015/09/21 Javascript
购物车前端开发(jQuery和bootstrap3)
2016/08/27 Javascript
jQuery实现简单的tab标签页效果
2016/09/12 Javascript
jQuery中$.grep() 过滤函数 数组过滤
2016/11/22 Javascript
jQuery插件HighCharts实现的2D回归直线散点效果示例【附demo源码下载】
2017/03/09 Javascript
vue.js删除动态绑定的radio的指定项
2017/06/02 Javascript
js canvas画布实现高斯模糊效果
2018/11/27 Javascript
vscode vue 文件模板的配置方法
2019/07/23 Javascript
js前端如何写一个精确的倒计时代码
2019/10/25 Javascript
vue的三种图片引入方式代码实例
2019/11/19 Javascript
[57:53]DOTA2上海特级锦标赛主赛事日 - 2 败者组第二轮#3OG VS VP
2016/03/03 DOTA
python使用fileinput模块实现逐行读取文件的方法
2015/04/29 Python
利用python微信库itchat实现微信自动回复功能
2017/05/18 Python
Django在pycharm下修改默认启动端口的方法
2019/07/26 Python
使用Python制作一个打字训练小工具
2019/10/01 Python
opencv python 图片读取与显示图片窗口未响应问题的解决
2020/04/24 Python
python 密码学示例——理解哈希(Hash)算法
2020/09/21 Python
全球最大的在线旅游公司:Expedia
2017/11/16 全球购物
纽约手袋品牌:KARA
2018/03/18 全球购物
澳大利亚香水在线商店:City Perfume
2020/09/02 全球购物
天猫某品牌专卖店运营计划书
2014/03/21 职场文书
全国爱眼日活动总结
2015/02/27 职场文书
保研推荐信格式
2015/03/25 职场文书
Redis如何一键部署脚本
2021/04/12 Redis
在python中实现导入一个需要传参的模块
2021/05/12 Python