机器学习经典算法-logistic回归代码详解


Posted in Python onDecember 22, 2017

一、算法简要

我们希望有这么一种函数:接受输入然后预测出类别,这样用于分类。这里,用到了数学中的sigmoid函数,sigmoid函数的具体表达式和函数图象如下:

机器学习经典算法-logistic回归代码详解

可以较为清楚的看到,当输入的x小于0时,函数值<0.5,将分类预测为0;当输入的x大于0时,函数值>0.5,将分类预测为1。

1.1 预测函数的表示

机器学习经典算法-logistic回归代码详解

1.2参数的求解

机器学习经典算法-logistic回归代码详解

二、代码实现

函数sigmoid计算相应的函数值;gradAscent实现的batch-梯度上升,意思就是在每次迭代中所有数据集都考虑到了;而stoGradAscent0中,则是将数据集中的示例都比那里了一遍,复杂度大大降低;stoGradAscent1则是对随机梯度上升的改进,具体变化是alpha每次变化的频率是变化的,而且每次更新参数用到的示例都是随机选取的。

from numpy import * 
import matplotlib.pyplot as plt 
def loadDataSet(): 
  dataMat = [] 
  labelMat = [] 
  fr = open('testSet.txt') 
  for line in fr.readlines(): 
    lineArr = line.strip('\n').split('\t') 
    dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])]) 
    labelMat.append(int(lineArr[2])) 
  fr.close() 
  return dataMat, labelMat 
def sigmoid(inX): 
  return 1.0/(1+exp(-inX)) 
def gradAscent(dataMatIn, classLabels): 
  dataMatrix = mat(dataMatIn) 
  labelMat = mat(classLabels).transpose() 
  m,n=shape(dataMatrix) 
  alpha = 0.001 
  maxCycles = 500 
  weights = ones((n,1)) 
  errors=[] 
  for k in range(maxCycles): 
    h = sigmoid(dataMatrix*weights) 
    error = labelMat - h 
    errors.append(sum(error)) 
    weights = weights + alpha*dataMatrix.transpose()*error 
  return weights, errors 
def stoGradAscent0(dataMatIn, classLabels): 
  m,n=shape(dataMatIn) 
  alpha = 0.01 
  weights = ones(n) 
  for i in range(m): 
    h = sigmoid(sum(dataMatIn[i]*weights)) 
    error = classLabels[i] - h  
    weights = weights + alpha*error*dataMatIn[i] 
  return weights 
def stoGradAscent1(dataMatrix, classLabels, numIter = 150): 
  m,n=shape(dataMatrix) 
  weights = ones(n) 
  for j in range(numIter): 
    dataIndex=range(m) 
    for i in range(m): 
      alpha= 4/(1.0+j+i)+0.01 
      randIndex = int(random.uniform(0,len(dataIndex))) 
      h = sigmoid(sum(dataMatrix[randIndex]*weights)) 
      error = classLabels[randIndex]-h 
      weights=weights+alpha*error*dataMatrix[randIndex] 
      del(dataIndex[randIndex]) 
    return weights 
def plotError(errs): 
  k = len(errs) 
  x = range(1,k+1) 
  plt.plot(x,errs,'g--') 
  plt.show() 
def plotBestFit(wei): 
  weights = wei.getA() 
  dataMat, labelMat = loadDataSet() 
  dataArr = array(dataMat) 
  n = shape(dataArr)[0] 
  xcord1=[] 
  ycord1=[] 
  xcord2=[] 
  ycord2=[] 
  for i in range(n):  
    if int(labelMat[i])==1: 
      xcord1.append(dataArr[i,1]) 
      ycord1.append(dataArr[i,2]) 
    else: 
      xcord2.append(dataArr[i,1]) 
      ycord2.append(dataArr[i,2]) 
  fig = plt.figure() 
  ax = fig.add_subplot(111) 
  ax.scatter(xcord1, ycord1, s=30, c='red', marker='s') 
  ax.scatter(xcord2, ycord2, s=30, c='green') 
  x = arange(-3.0,3.0,0.1) 
  y=(-weights[0]-weights[1]*x)/weights[2] 
  ax.plot(x,y) 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  plt.show() 
def classifyVector(inX, weights): 
  prob = sigmoid(sum(inX*weights)) 
  if prob>0.5: 
    return 1.0 
  else: 
    return 0 
def colicTest(ftr, fte, numIter): 
  frTrain = open(ftr) 
  frTest = open(fte) 
  trainingSet=[] 
  trainingLabels=[] 
  for line in frTrain.readlines(): 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    trainingSet.append(lineArr) 
    trainingLabels.append(float(currLine[21])) 
  frTrain.close() 
  trainWeights = stoGradAscent1(array(trainingSet),trainingLabels, numIter) 
  errorCount = 0 
  numTestVec = 0.0 
  for line in frTest.readlines(): 
    numTestVec += 1.0 
    currLine = line.strip('\n').split('\t') 
    lineArr=[] 
    for i in range(21): 
      lineArr.append(float(currLine[i])) 
    if int(classifyVector(array(lineArr), trainWeights))!=int(currLine[21]): 
      errorCount += 1 
  frTest.close() 
  errorRate = (float(errorCount))/numTestVec 
  return errorRate 
def multiTest(ftr, fte, numT, numIter): 
  errors=[] 
  for k in range(numT): 
    error = colicTest(ftr, fte, numIter) 
    errors.append(error) 
  print "There "+str(len(errors))+" test with "+str(numIter)+" interations in all!" 
  for i in range(numT): 
    print "The "+str(i+1)+"th"+" testError is:"+str(errors[i]) 
  print "Average testError: ", float(sum(errors))/len(errors) 
''''' 
data, labels = loadDataSet() 
weights0 = stoGradAscent0(array(data), labels) 
weights,errors = gradAscent(data, labels) 
weights1= stoGradAscent1(array(data), labels, 500) 
print weights 
plotBestFit(weights) 
print weights0 
weights00 = [] 
for w in weights0: 
  weights00.append([w]) 
plotBestFit(mat(weights00)) 
print weights1 
weights11=[] 
for w in weights1: 
  weights11.append([w]) 
plotBestFit(mat(weights11)) 
''' 
multiTest(r"horseColicTraining.txt",r"horseColicTest.txt",10,500)

总结

以上就是本文关于机器学习经典算法-logistic回归代码详解的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站:

如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

Python 相关文章推荐
python采用django框架实现支付宝即时到帐接口
May 17 Python
python机器学习之决策树分类详解
Dec 20 Python
caffe binaryproto 与 npy相互转换的实例讲解
Jul 09 Python
解决Python中pandas读取*.csv文件出现编码问题
Jul 12 Python
pytorch自定义初始化权重的方法
Aug 17 Python
Python制作词云图代码实例
Sep 09 Python
Python Collatz序列实现过程解析
Oct 12 Python
使用Keras中的ImageDataGenerator进行批次读图方式
Jun 17 Python
pytorch 限制GPU使用效率详解(计算效率)
Jun 27 Python
keras的三种模型实现与区别说明
Jul 03 Python
序列化Python对象的方法
Aug 01 Python
Python实现Excel文件的合并(以新冠疫情数据为例)
Mar 20 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 #Python
python实现数据预处理之填充缺失值的示例
Dec 22 #Python
NetworkX之Prim算法(实例讲解)
Dec 22 #Python
Python实现控制台中的进度条功能代码
Dec 22 #Python
Python中的探索性数据分析(功能式)
Dec 22 #Python
Python反射用法实例简析
Dec 22 #Python
Python文本特征抽取与向量化算法学习
Dec 22 #Python
You might like
PHP面向对象三大特点学习(充分理解抽象、封装、继承、多态)
2012/05/07 PHP
php版银联支付接口开发简明教程
2016/10/14 PHP
Yii2框架RESTful API 格式化响应,授权认证和速率限制三部分详解
2016/11/10 PHP
php判断是否为ajax请求的方法
2016/11/29 PHP
全面解析PHP面向对象的三大特征
2017/06/10 PHP
jquery按回车提交数据的代码示例
2013/11/05 Javascript
JS中解决谷歌浏览器记住密码输入框颜色改变功能
2017/02/13 Javascript
AngularJS执行流程详解
2017/02/17 Javascript
浅谈如何使用 webpack 优化资源
2017/10/20 Javascript
react-native android状态栏的实现
2018/06/15 Javascript
JS实现盒子跟着鼠标移动及键盘方向键控制盒子移动效果示例
2019/01/29 Javascript
微信小程序 image组件遇到的问题
2019/05/28 Javascript
layui按条件隐藏表格列的实例
2019/09/19 Javascript
JS代码实现页面切换效果
2021/01/10 Javascript
[56:48]FNATIC vs EG 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
Python中os.path用法分析
2015/01/15 Python
调试Python程序代码的几种方法总结
2015/04/28 Python
Python实现判断一个字符串是否包含子串的方法总结
2017/11/21 Python
Python判断对象是否为文件对象(file object)的三种方法示例
2019/04/26 Python
什么是python的列表推导式
2020/05/26 Python
关于python3.9安装wordcloud出错的问题及解决办法
2020/11/02 Python
canvas三角函数模拟水波效果的示例代码
2018/07/03 HTML / CSS
HTML5资源预加载(Link prefetch)详细介绍(给你的网页加速)
2014/05/07 HTML / CSS
HTML5页面直接调用百度地图API获取当前位置直接导航目的地的实现代码
2018/03/02 HTML / CSS
单身旅行者的单身假期:Just You
2018/04/08 全球购物
eVitamins日本:在线购买折扣维生素、补品和草药
2019/04/04 全球购物
50道外企软件测试面试题
2014/08/18 面试题
中间件分为哪几类
2012/03/14 面试题
个人现实表现材料
2014/02/04 职场文书
个人贷款承诺书
2014/03/28 职场文书
2014五一国际劳动节活动总结范文
2014/04/14 职场文书
活动总结报告格式
2014/05/09 职场文书
2016廉洁教育心得体会
2016/01/20 职场文书
文案策划岗位个人自我评价(范文)
2019/08/08 职场文书
matlab xlabel位置的设置方式
2021/05/21 Python
使用Springboot实现健身房管理系统
2021/07/01 Java/Android