机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
深入理解Python对Json的解析
Feb 14 Python
详解python函数传参是传值还是传引用
Jan 16 Python
Python socket实现的简单通信功能示例
Aug 21 Python
解决pycharm 误删掉项目文件的处理方法
Oct 22 Python
selenium跳过webdriver检测并模拟登录淘宝
Jun 12 Python
Python的numpy库下的几个小函数的用法(小结)
Jul 12 Python
pytorch 自定义参数不更新方式
Jan 06 Python
python torch.utils.data.DataLoader使用方法
Apr 02 Python
python实现简单学生信息管理系统
Apr 09 Python
Python读取excel文件中带公式的值的实现
Apr 17 Python
python自动打开浏览器下载zip并提取内容写入excel
Jan 04 Python
python中subplot大小的设置步骤
Jun 28 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
总集篇&特番节目先行播出!《SAO Alicization War of Underworld》第2季度TV动画4月25日放送!
2020/03/06 日漫
php循环创建目录示例分享(php创建多级目录)
2014/03/04 PHP
PHP附件下载中文名称乱码的解决方法
2015/12/17 PHP
Yii快速入门经典教程
2015/12/28 PHP
php中上传文件的的解决方案
2018/09/25 PHP
PHP parse_ini_file函数的应用与扩展操作示例
2019/01/07 PHP
简单实用的PHP文本缓存类实例
2019/03/22 PHP
Swoole 5将移除自动添加Event::wait()特性详解
2019/07/10 PHP
thinkPHP和onethink微信支付插件分享
2019/08/11 PHP
JavaScript入门教程 Cookies
2009/01/31 Javascript
对 lightbox JS 图片控件进行了一下改造, 使其他支持复杂的图片说明
2010/03/20 Javascript
精通Javascript系列之数值计算
2011/06/07 Javascript
jquery浏览器滚动加载技术实现方案
2014/06/03 Javascript
node.js中的require使用详解
2014/12/15 Javascript
轻松创建nodejs服务器(10):处理POST请求
2014/12/18 NodeJs
kindeditor修复会替换script内容的问题
2015/04/03 Javascript
基于JavaScript操作DOM常用的API小结
2015/12/01 Javascript
JS中Eval解析JSON字符串的一个小问题
2016/02/21 Javascript
微信小程序 本地数据读取实例
2017/04/27 Javascript
JS使用遮罩实现点击某区域以外时弹窗的弹出与关闭功能示例
2018/07/31 Javascript
create-react-app安装出错问题解决方法
2018/09/04 Javascript
Python计算三维矢量幅度的方法
2015/06/15 Python
详解字典树Trie结构及其Python代码实现
2016/06/03 Python
Python使用os.listdir()和os.walk()获取文件路径与文件下所有目录的方法
2019/04/01 Python
Python双链表原理与实现方法详解
2020/02/22 Python
python 已知平行四边形三个点,求第四个点的案例
2020/04/12 Python
pycharm配置安装autopep8自动规范代码的实现
2021/03/02 Python
乔丹诺(Giordano)酒庄德国官网:找到最好的意大利葡萄酒
2017/12/28 全球购物
送给程序员的20个Java集合面试问题
2014/08/06 面试题
实习生自荐信范文
2013/11/13 职场文书
岗位竞聘演讲稿
2014/01/10 职场文书
九年级数学教学反思
2014/02/02 职场文书
实习单位证明范例
2014/11/17 职场文书
小学运动会报道稿
2015/07/22 职场文书
2016年记者节感言
2015/12/08 职场文书
母婴行业实体、电商模式全面解析
2019/08/01 职场文书