python K近邻算法的kd树实现


Posted in Python onSeptember 06, 2018

k近邻算法的介绍

k近邻算法是一种基本的分类和回归方法,这里只实现分类的k近邻算法。

k近邻算法的输入为实例的特征向量,对应特征空间的点;输出为实例的类别,可以取多类。

k近邻算法不具有显式的学习过程,实际上k近邻算法是利用训练数据集对特征向量空间进行划分。将划分的空间模型作为其分类模型。

k近邻算法的三要素

  • k值的选择:即分类决策时选择k个最近邻实例;
  • 距离度量:即预测实例点和训练实例点间的距离,一般使用L2距离即欧氏距离;
  • 分类决策规则。

下面对三要素进行一下说明:

1.欧氏距离即欧几里得距离,高中数学中用来计算点和点间的距离公式;

2.k值选择:k值选择会对k近邻法结果产生重大影响,如果选择较小的k值,相当于在较小的邻域中训练实例进行预测,这样有点是“近似误差”会减小,即只与输入实例较近(相似)的训练实例才会起作用,缺点是“估计误差”会增大,即对近邻的实例点很敏感。而k值过大则相反。实际中取较小的k值通过交叉验证的方法取最优k值。

3.k近邻法的分类决策规则往往采用多数表决的方式,这等价于“经验风险最小化”。

k近邻算法的实现:kd树

实现k近邻法是要考虑的主要问题是如何退训练数据进行快速的k近邻搜索,当训练实例数很大是显然通过一般的线性搜索方式效率低下,因此为了提高搜索效率,需要构造特殊的数据结构对训练实例进行存储。kd树就是一种不错的数据结构,可以大大提高搜索效率。

本质商kd树是对k维空间的一个划分,构造kd树相当与使用垂直于坐标轴的超平面将k维空间进行切分,构造一系列的超矩形,kd树的每一个结点对应一个这样的超矩形。

kd树本质上是一棵二叉树,当通过一定规则构造是他是平衡的。

下面是过早kd树的算法:

  • 开始:构造根结点,根节点对应包含所有训练实例的k为空间。 选择第1维为坐标轴,以所有训练实例的第一维数据的中位数为切分点,将根结点对应的超矩形切分为两个子区域。由根结点生成深度为1的左右子结点,左结点对应第一维坐标小于切分点的子区域,右子结点对应第一位坐标大于切分点的子区域。
  • 重复:对深度为j的结点选择第l维为切分坐标轴,l=j(modk)+1,以该区域中所有训练实例的第l维的中位数为切分点,重复第一步。
  • 直到两个子区域没有实例存在时停止。形成kd树。

以下是kd树的python实现

准备工作

#读取数据准备
def file2matrix(filename):
  fr = open(filename)
  returnMat = []     #样本数据矩阵
  for line in fr.readlines():
    line = line.strip().split('\t')
    returnMat.append([float(line[0]),float(line[1]),float(line[2]),float(line[3])])
  return returnMat
  
#将数据归一化,避免数据各维度间的差异过大
def autoNorm(data):
  #将data数据和类别拆分
  data,label = np.split(data,[3],axis=1)
  minVals = data.min(0)   #data各列的最大值
  maxVals = data.max(0)    #data各列的最小值
  ranges = maxVals - minVals
  normDataSet = np.zeros(np.shape(data))
  m = data.shape[0]
  #tile函数将变量内容复制成输入矩阵同样大小的矩阵
  normDataSet = data - np.tile(minVals,(m,1))    
  normDataSet = normDataSet/np.tile(ranges,(m,1))
  #拼接
  normDataSet = np.hstack((normDataSet,label))
  return normDataSet
//数据实例
40920  8.326976  0.953952  3
14488  7.153469  1.673904  2
26052  1.441871  0.805124  1
75136  13.147394  0.428964  1
38344  1.669788  0.134296  1
72993  10.141740  1.032955  1
35948  6.830792  1.213192  3
42666  13.276369  0.543880  3
67497  8.631577  0.749278  1
35483  12.273169  1.508053  3
//每一行是一个数据实例,前三维是数据值,第四维是类别标记

树结构定义

#构建kdTree将特征空间划分
class kd_tree:
  """
  定义结点
  value:节点值
  dimension:当前划分的维数
  left:左子树
  right:右子树
  """
  def __init__(self, value):
    self.value = value
    self.dimension = None    #记录划分的维数
    self.left = None
    self.right = None
  
  def setValue(self, value):
    self.value = value
  
  #类似Java的toString()方法
  def __str__(self):
    return str(self.value)

kd树构造

def creat_kdTree(dataIn, k, root, deep):
  """
  data:要划分的特征空间(即数据集)
  k:表示要选择k个近邻
  root:树的根结点
  deep:结点的深度
  """
  #选择x(l)(即为第l个特征)为坐标轴进行划分,找到x(l)的中位数进行划分
#   x_L = data[:,deep%k]    #这里选取第L个特征的所有数据组成一个列表
  #获取特征值中位数,这里是难点如果numpy没有提供的话
  
  if(dataIn.shape[0]>0):   #如果该区域还有实例数据就继续
    dataIn = dataIn[dataIn[:,int(deep%k)].argsort()]    #numpy的array按照某列进行排序
    data1 = None; data2 = None
    #拿取根据xL排序的中位数的数据作为该子树根结点的value
    if(dataIn.shape[0]%2 == 0):   #该数据集有偶数个数据
      mid = int(dataIn.shape[0]/2)
      root = kd_tree(dataIn[mid,:])
      root.dimension = deep%k
      dataIn = np.delete(dataIn,mid, axis = 0)
      data1,data2 = np.split(dataIn,[mid], axis=0) 
      #mid行元素分到data2中,删除放到根结点中
    elif(dataIn.shape[0]%2 == 1):
      mid = int((dataIn.shape[0]+1)/2 - 1)  #这里出现递归溢出,当shape为(1,4)时出现,原因是np.delete时没有赋值给dataIn
      root = kd_tree(dataIn[mid,:])
      root.dimension = deep%k
      dataIn = np.delete(dataIn,mid, axis = 0)
      data1,data2 = np.split(dataIn,[mid], axis=0) #mid行元素分到data1中,删除放到根结点中
    #深度加一
    deep+=1
    #递归构造子树
    #这里犯了严重错误,递归调用是将root传递进去,造成程序混乱,应该给None
    root.left = creat_kdTree(data1, k, None, deep)
    root.right = creat_kdTree(data2, k, None, deep)
  return root

前序遍历测试

#前序遍历kd树
def preorder(kd_tree,i):
  print(str(kd_tree.value)+" :"+str(kd_tree.dimension)+":"+str(i))
  if kd_tree.left != None:
    preorder(kd_tree.left,i+1)
  if kd_tree.right != None:
    preorder(kd_tree.right,i+1)

kd树的最近邻搜索

最近邻搜索算法,k近邻搜索在此基础上实现

原理:首先找到包含目标点的叶节点;然后从该也结点出发,一次退回到父节点,不断查找与目标点最近的结点,当确定不可能存在更近的结点是停止。

def findClosest(kdNode,closestPoint,x,minDis,i=0):
  """
  这里存在一个问题,当传递普通的不可变对象minDis时,递归退回第一次找到
  最端距离前,minDis改变,最后结果混乱,这里传递一个可变对象进来。
  kdNode:是构造好的kd树。
  closestPoint:是存储最近点的可变对象,这里是array
  x:是要预测的实例
  minDis:是当前最近距离。
  """
  if kdNode == None:
    return
  #计算欧氏距离
  curDis = (sum((kdNode.value[0:3]-x[0:3])**2))**0.5
  if minDis[0] < 0 or curDis < minDis[0] :
    i+=1
    minDis[0] = curDis 
    closestPoint[0] = kdNode.value[0]
    closestPoint[1] = kdNode.value[1]
    closestPoint[2] = kdNode.value[2]
    closestPoint[3] = kdNode.value[3]
    print(str(closestPoint)+" : "+str(i)+" : "+str(minDis))
  #递归查找叶节点
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findClosest(kdNode.left,closestPoint,x,minDis,i)
  else:
    findClosest(kdNode.right, closestPoint, x, minDis,i) 
  #计算测试点和分隔超平面的距离,如果相交进入另一个叶节点重复
  rang = abs(x[kdNode.dimension] - kdNode.value[kdNode.dimension])
  if rang > minDis[0] :
    return
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findClosest(kdNode.right,closestPoint,x,minDis,i)
  else:
    findClosest(kdNode.left, closestPoint, x, minDis,i)

测试:

data = file2matrix("datingTestSet2.txt")
data = np.array(data)
normDataSet = autoNorm(data)
sys.setrecursionlimit(10000)      #设置递归深度为10000
trainSet,testSet = np.split(normDataSet,[900],axis=0) 
kdTree = creat_kdTree(trainSet, 3, None, 0)
newData = testSet[1,0:3]
closestPoint = np.zeros(4)
minDis = np.array([-1.0])
findClosest(kdTree, closestPoint, newData, minDis)
print(closestPoint)
print(testSet[1,:])
print(minDis)

测试结果

[0.35118819 0.43961918 0.67110669 3.        ] : 1 : [0.40348346]
[0.11482037 0.13448927 0.48293309 2.        ] : 2 : [0.30404792]
[0.12227055 0.07902201 0.57826697 2.        ] : 3 : [0.22272422]
[0.0645755  0.10845299 0.83274698 2.        ] : 4 : [0.07066192]
[0.10020488 0.15196271 0.76225551 2.        ] : 5 : [0.02546591]
[0.10020488 0.15196271 0.76225551 2.        ]
[0.08959933 0.15442555 0.78527657 2.        ]
[0.02546591]

k近邻搜索实现

在最近邻的基础上进行改进得到:

这里的closestPoint和minDis合并,一同处理

#k近邻搜索
def findKNode(kdNode, closestPoints, x, k):
  """
  k近邻搜索,kdNode是要搜索的kd树
  closestPoints:是要搜索的k近邻点集合,将minDis放入closestPoints最后一列合并
  x:预测实例
  minDis:是最近距离
  k:是选择k个近邻
  """
  if kdNode == None:
    return
  #计算欧式距离
  curDis = (sum((kdNode.value[0:3]-x[0:3])**2))**0.5
  #将closestPoints按照minDis列排序,这里存在一个问题,排序后返回一个新对象
  #不能将其直接赋值给closestPoints
  tempPoints = closestPoints[closestPoints[:,4].argsort()]
  for i in range(k):
    closestPoints[i] = tempPoints[i]
  #每次取最后一行元素操作
  if closestPoints[k-1][4] >=10000 or closestPoints[k-1][4] > curDis:
    closestPoints[k-1][4] = curDis
    closestPoints[k-1,0:4] = kdNode.value 
    
  #递归搜索叶结点
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findKNode(kdNode.left, closestPoints, x, k)
  else:
    findKNode(kdNode.right, closestPoints, x, k)
  #计算测试点和分隔超平面的距离,如果相交进入另一个叶节点重复
  rang = abs(x[kdNode.dimension] - kdNode.value[kdNode.dimension])
  if rang > closestPoints[k-1][4]:
    return
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findKNode(kdNode.right, closestPoints, x, k)
  else:
    findKNode(kdNode.left, closestPoints, x, k)

测试

data = file2matrix("datingTestSet2.txt")
data = np.array(data)
normDataSet = autoNorm(data)
sys.setrecursionlimit(10000)      #设置递归深度为10000
trainSet,testSet = np.split(normDataSet,[900],axis=0) 
kdTree = creat_kdTree(trainSet, 3, None, 0)
newData = testSet[1,0:3]
print("预测实例点:"+str(newData))
closestPoints = np.zeros((3,5))     #初始化参数
closestPoints[:,4] = 10000.0      #给minDis列赋值
findKNode(kdTree, closestPoints, newData, 3)
print("k近邻结果:"+str(closestPoints))

测试结果

预测实例点:[0.08959933 0.15442555 0.78527657]

k近邻结果:[[0.10020488 0.15196271 0.76225551 2.         0.02546591]
 [0.10664709 0.13172159 0.83777837 2.         0.05968697]
 [0.09616206 0.20475001 0.75047289 2.         0.06153793]]

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 文件与目录操作
Dec 24 Python
Python实现比较两个文件夹中代码变化的方法
Jul 10 Python
Python 3.6 性能测试框架Locust安装及使用方法(详解)
Oct 11 Python
Python+matplotlib+numpy绘制精美的条形统计图
Jan 02 Python
利用python 更新ssh 远程代码 操作远程服务器的实现代码
Feb 08 Python
解决Pycharm后台indexing导致不能run的问题
Jun 27 Python
解决Django删除migrations文件夹中的文件后出现的异常问题
Aug 31 Python
Python操作qml对象过程详解
Sep 26 Python
python使用Geany编辑器配置方法
Feb 21 Python
python 删除excel表格重复行,数据预处理操作
Jul 06 Python
详解KMP算法以及python如何实现
Sep 18 Python
关于探究python中sys.argv时遇到的问题详解
Feb 23 Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
Python多线程编程之多线程加锁操作示例
Sep 06 #Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 #Python
Python json模块dumps、loads操作示例
Sep 06 #Python
Python 字符串换行的多种方式
Sep 06 #Python
Python使用logging模块实现打印log到指定文件的方法
Sep 05 #Python
Python使用try except处理程序异常的三种常用方法分析
Sep 05 #Python
You might like
利用static实现表格的颜色隔行显示
2006/10/09 PHP
目录,文件操作详谈―PHP
2006/11/25 PHP
关于php 高并发解决的一点思路
2017/04/16 PHP
PHP连接SQL Server的方法分析【基于thinkPHP5.1框架】
2019/05/06 PHP
Yii2.0框架模型添加/修改/删除数据操作示例
2019/07/18 PHP
JavaScript的Function详细
2006/11/14 Javascript
在修改准备发的批量美化select+可修改select时,在非IE下发现了几个问题
2007/01/09 Javascript
Tinymce+jQuery.Validation使用产生的BUG
2010/03/29 Javascript
javascript实现汉字转拼音代码分享
2015/04/20 Javascript
简述JavaScript对传统文档对象模型的支持
2015/06/16 Javascript
微信小程序的动画效果详解
2017/01/18 Javascript
jQuery插件HighCharts实现的2D对数饼图效果示例【附demo源码下载】
2017/03/09 Javascript
Taro集成Redux快速上手的方法示例
2018/06/21 Javascript
node.js使用redis储存session的方法
2018/09/26 Javascript
简述vue-cli中chainWebpack的使用方法
2019/07/30 Javascript
微信小程序选择图片控件
2021/01/19 Javascript
基于Python_脚本CGI、特点、应用、开发环境(详解)
2017/05/23 Python
python读取中文txt文本的方法
2018/04/12 Python
Python3之简单搭建自带服务器的实例讲解
2018/06/04 Python
详解pandas安装若干异常及解决方案总结
2019/01/10 Python
python获取交互式ssh shell的方法
2019/02/14 Python
简单了解为什么python函数后有多个括号
2019/12/19 Python
使用Python实现将多表分批次从数据库导出到Excel
2020/05/15 Python
深入了解Python 方法之类方法 &amp; 静态方法
2020/08/17 Python
Django创建一个后台的基本步骤记录
2020/10/02 Python
amazeui模态框弹出后立马消失并刷新页面
2020/08/19 HTML / CSS
Otticanet意大利:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/03/10 全球购物
期末学生评语大全
2014/04/24 职场文书
教师节宣传方案
2014/05/23 职场文书
要账委托书范本
2014/09/15 职场文书
国际商务专业毕业生自我鉴定2014
2014/09/27 职场文书
销售代理协议书
2014/09/30 职场文书
个人作风建设剖析材料
2014/10/11 职场文书
信访维稳工作汇报
2014/10/27 职场文书
少年雷锋观后感
2015/06/10 职场文书
VUE中的v-if与v-show区别介绍
2022/03/13 Vue.js