利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python转换字符串为摩尔斯电码的方法
Jul 06 Python
Python中enumerate()函数编写更Pythonic的循环
Mar 06 Python
python将txt文件读入为np.array的方法
Oct 30 Python
python-opencv颜色提取分割方法
Dec 08 Python
Django中使用极验Geetest滑动验证码过程解析
Jul 31 Python
Django1.11配合uni-app发起微信支付的实现
Oct 12 Python
python实现人机猜拳小游戏
Feb 03 Python
django的autoreload机制实现
Jun 03 Python
python操作微信自动发消息的实现(微信聊天机器人)
Jul 14 Python
一文详述 Python 中的 property 语法
Sep 01 Python
协程Python 中实现多任务耗资源最小的方式
Oct 19 Python
Python中openpyxl实现vlookup函数的实例
Oct 28 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
PHP新手NOTICE错误常见解决方法
2011/12/07 PHP
基于php常用函数总结(数组,字符串,时间,文件操作)
2013/06/27 PHP
PHP实现RTX发送消息提醒的实例代码
2017/01/03 PHP
在chrome中window.onload事件的一些问题
2010/03/01 Javascript
javascript cookie操作类的实现代码小结附使用方法
2010/06/02 Javascript
js解析与序列化json数据(三)json的解析探讨
2013/02/01 Javascript
JavaScript实现的使用键盘控制人物走动实例
2014/08/27 Javascript
javascript 动态创建表格
2015/01/08 Javascript
JS中产生标识符方式的演变
2015/06/12 Javascript
javascript带回调函数的异步脚本载入方法实例分析
2015/07/02 Javascript
深入探究AngularJs之$scope对象(作用域)
2017/07/20 Javascript
最基础的vue.js双向绑定操作
2017/08/23 Javascript
node.js 模块和其下载资源的镜像设置的方法
2018/09/06 Javascript
jquery.param()实现数组或对象的序列化方法
2018/10/08 jQuery
JS实现头条新闻的经典轮播图效果示例
2019/01/30 Javascript
vue动态路由:路由参数改变,视图不更新问题的解决
2019/11/05 Javascript
python迭代器的使用方法实例
2013/11/21 Python
python使用心得之获得github代码库列表
2014/06/25 Python
python数组循环处理方法
2019/08/26 Python
Python threading的使用方法解析
2019/08/28 Python
np.newaxis 实现为 numpy.ndarray(多维数组)增加一个轴
2019/11/30 Python
python打印文件的前几行或最后几行教程
2020/02/13 Python
详解python变量与数据类型
2020/08/25 Python
anaconda安装pytorch1.7.1和torchvision0.8.2的方法(亲测可用)
2021/02/01 Python
沙特阿拉伯网上购物:Sayidaty Mall
2018/05/06 全球购物
NYX Professional Makeup官方网站:专业彩妆和美容产品
2019/10/29 全球购物
AJAX的全称是什么
2012/11/06 面试题
九州传奇上机题
2014/07/10 面试题
老公给老婆的道歉信
2014/01/10 职场文书
班子群众路线教育实践个人对照检查材料思想汇报
2014/09/30 职场文书
领导参观欢迎词
2015/01/26 职场文书
班主任自我评价范文
2015/03/11 职场文书
毕业论文答辩开场白和答辩技巧
2015/05/27 职场文书
职工培训工作总结
2015/08/10 职场文书
2016秋季小学开学寄语
2015/12/03 职场文书
Java org.w3c.dom.Document 类方法引用报错
2021/08/07 Java/Android