利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的gevent框架的入门教程
Apr 29 Python
浅谈python中的数字类型与处理工具
Aug 02 Python
代码讲解Python对Windows服务进行监控
Feb 11 Python
Numpy数组转置的两种实现方法
Apr 17 Python
python 获取键盘输入,同时有超时的功能示例
Nov 13 Python
Python3爬虫爬取英雄联盟高清桌面壁纸功能示例【基于Scrapy框架】
Dec 05 Python
使用pandas实现连续数据的离散化处理方式(分箱操作)
Nov 22 Python
使用ITK-SNAP进行抠图操作并保存mask的实例
Jul 01 Python
Django 用户认证Auth组件的使用
Nov 30 Python
python爬取企查查企业信息之selenium自动模拟登录企查查
Apr 08 Python
Python基于Tkinter开发一个爬取B站直播弹幕的工具
May 06 Python
Pandas 数据编码的十种方法
Apr 20 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
PHP 5.3.1 安装包 VC9 VC6不同版本的区别是什么
2010/07/04 PHP
codeigniter显示所有脚本执行时间的方法
2015/03/21 PHP
php判断用户是否手机访问代码
2015/06/08 PHP
PHP实现的链式队列结构示例
2017/09/15 PHP
在Laravel中使用DataTables插件的方法
2018/05/29 PHP
利用谷歌地图API获取点与点的距离的js代码
2012/10/11 Javascript
在服务端(Page.Write)调用自定义的JS方法详解
2013/08/09 Javascript
JQuery右键菜单插件ContextMenu使用指南
2014/12/19 Javascript
jQuery实现html元素拖拽
2015/07/21 Javascript
浅析四种常见的Javascript声明循环变量的书写方式
2015/10/14 Javascript
深入理解angularjs过滤器
2016/05/25 Javascript
jquery分隔Url的param方法(推荐)
2016/05/25 Javascript
jquery实现焦点轮播效果
2017/02/23 Javascript
详解nodejs微信公众号开发——6.自定义菜单
2017/04/13 NodeJs
CSS3+JavaScript实现翻页幻灯片效果
2017/06/28 Javascript
p5.js临摹动态图形的方法
2019/10/23 Javascript
Python中使用PIPE操作Linux管道
2015/02/04 Python
python 删除列表里所有空格项的方法总结
2018/04/18 Python
Python3之手动创建迭代器的实例代码
2019/05/22 Python
Python+selenium点击网页上指定坐标的实例
2019/07/05 Python
python中使用while循环的实例
2019/08/05 Python
flask框架url与重定向操作实例详解
2020/01/25 Python
Django后端分离 使用element-ui文件上传方式
2020/07/12 Python
python 如何设置守护进程
2020/10/29 Python
Ubuntu20下的Django安装的方法步骤
2021/01/24 Python
Soft Cotton捷克:来自爱琴海棉花的浴袍
2017/02/01 全球购物
德国童装购物网站:NICKI´S.com
2018/04/20 全球购物
全球最大的在线橄榄球商店:Lovell Rugby
2018/05/20 全球购物
介绍Java的内部类
2012/10/27 面试题
大学生简单自荐信
2013/11/10 职场文书
买房协议书
2014/04/11 职场文书
加入学生会演讲稿
2014/04/24 职场文书
县政府办公室领导班子个人对照检查材料
2014/09/16 职场文书
我们的节日端午节活动总结
2015/02/11 职场文书
html实现随机点名器的示例代码
2021/04/02 Javascript
详解使用内网穿透工具Ngrok代理本地服务
2022/03/31 Servers