使用python实现kNN分类算法


Posted in Python onOctober 16, 2019

k-近邻算法是基本的机器学习算法,算法的原理非常简单:

输入样本数据后,计算输入样本和参考样本之间的距离,找出离输入样本距离最近的k个样本,找出这k个样本中出现频率最高的类标签作为输入样本的类标签,很直观也很简单,就是和参考样本集中的样本做对比。下面讲一讲用python实现kNN算法的方法,这里主要用了python中常用的numpy模块,采用的数据集是来自UCI的一个数据集,总共包含1055个样本,每个样本有41个real的属性和一个类标签,包含两类(RB和NRB)。我选取800条样本作为参考样本,剩下的作为测试样本。

下面是分类器的python代码:

'''
kNNClassify(inputAttr, trainSetPath = '', lenOfInstance = 42, startAttr = 0, stopAttr = 40, posOfClass = 41, numOfRefSamples = 5)函数
参数:
inputAttr:输入的属性向量
trainSetPath:字符串,保存训练样本的路径
lenOfInstance:样本向量的维数
startAttr:属性向量在整个样本向量中的起始下标
stopAttr:属性向量在整个样本向量中的终止下标
posOfClass:类标签的在整个样本向量中的下标
numOfClSamples:选出来进行投票的样本个数
返回值:
类标签
'''
 
def kNNClassify(inputAttr, trainSetPath = '', lenOfInstance = 42, startAttr = 0, stopAttr = 40, posOfClass = 41, numOfRefSamples = 5):
  fr = open(trainSetPath)
  strOfLine = fr.readline()
  arrayOfLine = numpy.array([0.] * lenOfInstance)
  refSamples = numpy.array([[-1., 0.]] * numOfRefSamples)
  
  #找出属性中的最大值和最小值,用于归一化
  maxAttr, minAttr = kNNFunction.dataNorm(trainSetPath = trainSetPath, lenOfInstance = lenOfInstance)
  maxAttr = maxAttr[(numpy.array(range(stopAttr - startAttr + 1)) 
            + numpy.array([startAttr] * (stopAttr - startAttr + 1)))]
  minAttr = minAttr[(numpy.array(range(stopAttr - startAttr + 1)) 
            + numpy.array([startAttr] * (stopAttr - startAttr + 1)))]
  attrRanges = maxAttr - minAttr
  
  inputAttr = inputAttr[(numpy.array(range(stopAttr - startAttr + 1)) 
              + numpy.array([startAttr] * (stopAttr - startAttr + 1)))]
  inputAttr = (inputAttr - minAttr) / attrRanges       #归一化
  
  #将字符串转换为向量并进行计算找出离输入样本距离最近的numOfRefSamples个参考样本
  while strOfLine != '' :
    strOfLine = strOfLine.strip()
    strOfLine = strOfLine.split(';')
    
    abandonOrNot = False
    for i in range(lenOfInstance) :
      if strOfLine[i] == 'RB' :
        arrayOfLine[i] = 1.0
      elif strOfLine[i] == 'NRB' :
        arrayOfLine[i] = 0.0
      elif strOfLine[i] != '?' :             #没有发现缺失值
        arrayOfLine[i] = float(strOfLine[i])     
        abandonOrNot = False
      else :                      #发现缺失值
        abandonOrNot = True
        break
    
    if abandonOrNot == True :
      strOfLine = fr.readline()
      continue
    else :
      attr = arrayOfLine[(numpy.array(range(stopAttr - startAttr + 1)) 
                + numpy.array([startAttr] * (stopAttr - startAttr + 1)))]
      attr = (attr - minAttr) / attrRanges      #归一化    
      classLabel = arrayOfLine[posOfClass]
      distance = (attr - inputAttr) ** 2
      distance = distance.sum(axis = 0)
      distance = distance ** 0.5
      disAndLabel = numpy.array([distance, classLabel])
      refSamples = kNNFunction.insertItem(refSamples, numOfRefSamples, disAndLabel)
      strOfLine = fr.readline()
      continue
    
  #统计每个类标签出现的次数
  classCount = {}
  for i in range(numOfRefSamples) :
    voteLabel = refSamples[i][1]
    classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
  sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)
  
  return int(sortedClassCount[0][0])

实现步骤为:读取一条样本,转换为向量,计算这条样本与输入样本的距离,将样本插入到refSamples数组中,当然这里的样本只是一个包含两个元素的数组(距离和类标签),而refSamples数组用于保存离输入样本最近的numOfRefSamples个参考样本。当所有样本都读完之后,就找出了离输入样本最近的numOfRefSamples个参考样本。其中kNNFunction.insertItem函数实现的是将得到的新样本插入到refSamples数组中,主要采用类似冒泡排序的方法,实现代码如下:

'''
insertItem(refSamples, numOfRefSamples, disAndLabel)函数
功能:
在参考样本集中插入新样本,这里的样本是一个包含两个数值的list,第一个是距离,第二个是类标签
在参考样本集中按照距离从小到大排列
参数:
refSamples:参考样本集
numOfRefSamples:参考样本集中的样本总数
disAndLabel:需要插入的样本数
'''
 
def insertItem(refSamples, numOfRefSamples, disAndLabel):
  if (disAndLabel[0] < refSamples[numOfRefSamples - 1][0]) or (refSamples[numOfRefSamples - 1][0] < 0) :
    refSamples[numOfRefSamples - 1] = disAndLabel
    for i in (numpy.array([numOfRefSamples - 2] * (numOfRefSamples - 1)) - numpy.array(range(numOfRefSamples -1))) :
      if (refSamples[i][0] > refSamples[i + 1][0]) or (refSamples[i][0] < 0) :
        tempSample = list(refSamples[i])
        refSamples[i] = refSamples[i + 1]
        refSamples[i + 1] = tempSample
      else :
        break
    return refSamples
  else :
    return refSamples

另外,需要注意的一点是要对输入样本的各条属性进行归一化处理。毕竟不同的属性的取值范围不一样,取值范围大的属性在计算距离的过程中所起到的作用自然就要大一些,所以有必要把所有属性映射到0和1之间。这就需要计算每个属性的最大值和最小值,方法就是遍历整个参考样本集,找出最大值和最小样本,这里用dataNorm函数是实现:

'''
归一化函数,返回归一化向量
'''
def dataNorm(trainSetPath = '', lenOfInstance = 42):
  fr = open(trainSetPath)
  strOfLine = fr.readline()                #从文件中读取的一行字符串
  arrayOfLine = numpy.array([0.] * lenOfInstance)       #用来保存与字符串对应的数组
  maxAttr = numpy.array(['NULL'] * lenOfInstance)       #用来保存每条属性的最大值
  minAttr = numpy.array(['NULL'] * lenOfInstance)       #用来保存每条属性的最小值
  
  while strOfLine != '' :
    strOfLine = strOfLine.strip()            #去掉字符串末尾的换行符
    strOfLine = strOfLine.split(';')           #将字符串按逗号分割成字符串数组
    
    abandonOrNot = False                 
    for i in range(lenOfInstance) :
      if strOfLine[i] == 'RB' :
        arrayOfLine[i] = 1.0
      elif strOfLine[i] == 'NRB' :
        arrayOfLine[i] = 0.0
      elif strOfLine[i] != '?' :             #没有发现缺失值
        arrayOfLine[i] = float(strOfLine[i])     
        abandonOrNot = False
      else :                      #发现缺失值
        abandonOrNot = True
        break
    
    if abandonOrNot == True :              #存在缺失值,丢弃
      strOfLine = fr.readline()
      continue
    else :                        #没有缺失值,保留
      if maxAttr[0] == 'NULL' or minAttr[0] == 'NULL' :     #maxAttr和minAttr矩阵是空的
        maxAttr = numpy.array(arrayOfLine)
        minAttr = numpy.array(arrayOfLine)
        strOfLine = fr.readline()
        continue
      for i in range(lenOfInstance) :
        if maxAttr[i] < arrayOfLine[i] :
          maxAttr[i] = float(arrayOfLine[i])
        if minAttr[i] > arrayOfLine[i] :
          minAttr[i] = float(arrayOfLine[i])
      strOfLine = fr.readline()
      continue
    
  return maxAttr, minAttr

至此为止,分类器算是完成,接下去就是用剩下的测试集进行测试,计算分类的准确度,用kNNTest函数实现:

def kNNTest(testSetPath = '', trainSetPath = '', lenOfInstance = 42, startAttr = 0, stopAttr = 40, posOfClass = 41):
  fr = open(testSetPath)
  strOfLine = fr.readline()
  arrayOfLine = numpy.array([0.] * lenOfInstance)
  succeedClassify = 0.0
  failedClassify = 0.0
  
  while strOfLine != '' :
    strOfLine = strOfLine.strip()
    strOfLine = strOfLine.split(';')
    
    abandonOrNot = False
    for i in range(lenOfInstance) :
      if strOfLine[i] == 'RB' :
        arrayOfLine[i] = 1.0
      elif strOfLine[i] == 'NRB' :
        arrayOfLine[i] = 0.0
      elif strOfLine[i] != '?' :             #没有发现缺失值
        arrayOfLine[i] = float(strOfLine[i])     
        abandonOrNot = False
      else :                      #发现缺失值
        abandonOrNot = True
        break
    
    if abandonOrNot == True :
      strOfLine = fr.readline()
      continue
    else :
      inputAttr = numpy.array(arrayOfLine)
      classLabel = kNNClassify(inputAttr, trainSetPath = trainSetPath, lenOfInstance = 42, startAttr = startAttr, 
                   stopAttr = stopAttr, posOfClass = posOfClass)
      if classLabel == arrayOfLine[posOfClass] :
        succeedClassify = succeedClassify + 1.0
      else :
        failedClassify = failedClassify + 1.0
      strOfLine = fr.readline()
      
  accuracy = succeedClassify / (succeedClassify + failedClassify)
  return accuracy

最后,进行测试:

accuracy = kNN.kNNTest(testSetPath = 'D:\\python_project\\test_data\\QSAR-biodegradation-Data-Set\\biodeg-test.csv', 
            trainSetPath = 'D:\\python_project\\test_data\\QSAR-biodegradation-Data-Set\\biodeg-train.csv', 
            startAttr = 0, stopAttr = 40)
print '分类准确率为:',accuracy

输出结果为:

分类准确率为: 0.847058823529

可见用kNN这种分类器的对这个数据集的分类效果其实还是比较一般的,而且根据我的测试,分类函数kNNClassify中numOfRefSamples(其实就是k-近邻中k)的取值对分类准确度也有明显的影响,大概在k取5的时候,分类效果比较理想,并不是越大越好。下面谈谈我对这个问题的理解:

首先,kNN算法是一种简单的分类算法,不需要任何训练过程,在样本数据的结构比较简单边界比较明显的时候,它的分类效果是比较理想的,比如:

使用python实现kNN分类算法

当k的取值比较大的时候,在某些复杂的边界下会出现很差的分类效果,比如下面的情况下很多蓝色的类会被分到红色中,所以要用比较小的k才会有相对较好的分类效果:

使用python实现kNN分类算法

但是当k取得太小也会使分类效果变差,比如当不同类的样本数据之间边界不明显,存在交叉的时候,比如:

使用python实现kNN分类算法

总的来说,kNN分类算法是一种比较原始直观的分类算法,对某些简单的情况有比较好的分类效果,并且不需要训练模型。但是它的缺点是分类过程的运算复杂度很高,而且当样本数据的结构比较复杂的时候,它的分类效果不理想。用kNN算法对本次实验中的数据集的分类效果也比较一般,不过我试过其它更简单一些的数据集,确实还是会有不错的分类准确性的,这里就不赘述了。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中apply函数的用法实例教程
Jul 31 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
python:接口间数据传递与调用方法
Dec 17 Python
python处理excel绘制雷达图
Oct 18 Python
Python笔记之代理模式
Nov 20 Python
opencv3/C++实现视频背景去除建模(BSM)
Dec 11 Python
VS2019+python3.7+opencv4.1+tensorflow1.13配置详解
Apr 16 Python
Visual Studio code 配置Python开发环境
Sep 11 Python
Python基于staticmethod装饰器标示静态方法
Oct 17 Python
python中使用np.delete()的实例方法
Feb 01 Python
使用Pytorch训练two-head网络的操作
May 28 Python
Python 键盘事件详解
Nov 11 Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
win10环境下配置vscode python开发环境的教程详解
Oct 16 #Python
500行代码使用python写个微信小游戏飞机大战游戏
Oct 16 #Python
You might like
给海燕B411配件机起死回生配上件
2021/03/02 无线电
各种咖啡的英文名子是什么
2021/03/03 新手入门
Fatal error: Call to undefined function curl_init()解决方法
2010/04/09 PHP
Zend Framework框架实现类似Google搜索分页效果
2016/11/25 PHP
php微信公众号开发之关键词回复
2018/10/20 PHP
PHP中的访问修饰符简单比较
2019/02/02 PHP
Laravel-添加后台模板AdminLte的实现方法
2019/10/08 PHP
代码精简的可以实现元素圆角的js函数
2007/07/21 Javascript
关于B/S判断浏览器断开的问题讨论
2008/10/29 Javascript
关于js new Date() 出现NaN 的分析
2012/10/23 Javascript
Jquery实现网页跳转或用命令打开指定网页的解决方法
2013/07/09 Javascript
js导入导出excel(实例代码)
2013/11/25 Javascript
jQuery Migrate 1.1.0 Released 注意事项
2014/06/14 Javascript
JavaScript基础语法、dom操作树及document对象
2014/12/02 Javascript
第七篇Bootstrap表单布局实例代码详解(三种表单布局)
2016/06/21 Javascript
JS框架之vue.js(深入三:组件1)
2016/09/29 Javascript
vue异步axios获取的数据渲染到页面的方法
2018/08/09 Javascript
如何使用puppet替换文件中的string
2018/12/06 Javascript
layui中的switch开关实现方法
2019/09/03 Javascript
使用flow来规范javascript的变量类型
2019/09/12 Javascript
微信小程序音乐播放器开发
2019/11/20 Javascript
vue-amap根据地址回显地图并mark的操作
2020/11/03 Javascript
vue 图片裁剪上传组件的实现
2020/11/12 Javascript
[02:27]2014DOTA2国际邀请赛 VG赛后采访:更大的挑战在等着我们
2014/07/13 DOTA
python识别文字(基于tesseract)代码实例
2019/08/24 Python
canvas实现手机的手势解锁的步骤详细
2020/03/16 HTML / CSS
马来西亚网上购物平台:ezbuy
2018/02/13 全球购物
英国独特的时尚和生活方式品牌:JOY
2018/03/17 全球购物
公司新年寄语
2014/04/04 职场文书
美德少年事迹材料500字
2014/08/19 职场文书
结婚司仪主持词
2015/06/29 职场文书
2015年统计员个人工作总结
2015/07/23 职场文书
vue点击弹窗自动触发点击事件的解决办法(模拟场景)
2021/05/25 Vue.js
mysql 获取时间方式
2022/03/20 MySQL
Ubuntu安装Mysql+启用远程连接的完整过程
2022/06/21 Servers
javascript中Set、Map、WeakSet、WeakMap区别
2022/12/24 Javascript