利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x版本中cmp()方法的使用教程
May 14 Python
插入排序_Python与PHP的实现版(推荐)
May 11 Python
python 3.6 tkinter+urllib+json实现火车车次信息查询功能
Dec 20 Python
python学生管理系统代码实现
Apr 05 Python
win10 64bit下python NLTK安装教程
Sep 19 Python
python 读取鼠标点击坐标的实例
Dec 29 Python
Python按钮的响应事件详解
Mar 04 Python
教你如何编写、保存与运行Python程序的方法
Jul 12 Python
Python 在OpenCV里实现仿射变换—坐标变换效果
Aug 30 Python
Python面向对象之私有属性和私有方法应用案例分析
Dec 31 Python
解决运行django程序出错问题 'str'object has no attribute'_meta'
Jul 15 Python
numpy实现RNN原理实现
Mar 02 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
全国FM电台频率大全 - 7 吉林省
2020/03/11 无线电
windows下PHP APACHE MYSQ完整配置
2007/01/02 PHP
兼容PHP和Java的des加密解密代码分享
2014/06/26 PHP
Joomla实现组件中弹出一个模式(modal)窗口的方法
2016/05/04 PHP
PHP加MySQL消息队列深入理解
2021/02/27 PHP
JQuery toggle使用分析
2009/11/16 Javascript
用javascript作一个通用向导说明
2011/08/30 Javascript
js 获取坐标 通过JS得到当前焦点(鼠标)的坐标属性
2013/01/04 Javascript
jquery easyui中treegrid用法的简单实例
2014/02/18 Javascript
JavaScript日期时间与时间戳的转换函数分享
2015/01/31 Javascript
在JavaScript中使用对数Math.log()方法的教程
2015/06/15 Javascript
浅谈Jquery核心函数
2015/06/18 Javascript
基于JavaScript创建动态Dom
2015/12/08 Javascript
利用angularjs1.4制作的简易滑动门效果
2017/02/28 Javascript
js评分组件使用详解
2017/06/06 Javascript
vue 引入公共css文件的简单方法(推荐)
2018/01/20 Javascript
微信小程序实现订单倒计时
2020/11/01 Javascript
JavaScript中layim之整合右键菜单的示例代码
2021/02/06 Javascript
python django事务transaction源码分析详解
2017/03/17 Python
ubuntu安装sublime3并配置python3环境的方法
2018/03/15 Python
Python数据结构之哈夫曼树定义与使用方法示例
2018/04/22 Python
Django Admin中增加导出CSV功能过程解析
2019/09/04 Python
TensorFlow实现指数衰减学习率的方法
2020/02/05 Python
如何基于python把文字图片写入word文档
2020/07/31 Python
15个应该掌握的Jupyter Notebook使用技巧(小结)
2020/09/23 Python
CSS3 box-sizing属性
2009/04/17 HTML / CSS
英国现代家具和装饰网站:PN Home
2018/08/16 全球购物
体育学院毕业生自荐信
2013/11/03 职场文书
学校门卫岗位职责
2014/03/16 职场文书
国际语言毕业生求职信
2014/07/08 职场文书
食品安全汇报材料
2014/08/18 职场文书
公务员群众路线专题民主生活会发言材料
2014/09/17 职场文书
入党群众意见范文
2015/06/02 职场文书
房屋质量投诉书
2015/07/02 职场文书
巾帼建功标兵先进事迹材料
2016/02/29 职场文书
Golang并发工具Singleflight
2022/05/06 Golang