利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简析Python的闭包和装饰器
Feb 26 Python
浅谈Python数据类型之间的转换
Jun 08 Python
python 文件操作api(文件操作函数)
Aug 28 Python
Python实现PS滤镜的万花筒效果示例
Jan 23 Python
对pycharm 修改程序运行所需内存详解
Dec 03 Python
python使用KNN算法识别手写数字
Apr 25 Python
简单了解为什么python函数后有多个括号
Dec 19 Python
pytorch .detach() .detach_() 和 .data用于切断反向传播的实现
Dec 27 Python
Python中bisect的使用方法
Dec 31 Python
Python获取对象属性的几种方式小结
Mar 12 Python
PyTorch中clone()、detach()及相关扩展详解
Dec 09 Python
python性能测试工具locust的使用
Dec 28 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
smarty的保留变量问题
2008/10/23 PHP
php入门学习知识点二 PHP简单的分页过程与原理
2011/07/14 PHP
php将access数据库转换到mysql数据库的方法
2014/12/24 PHP
PHP+Mysql无刷新问答评论系统(源码)
2016/12/20 PHP
Laravel网站打开速度优化的方法汇总
2017/07/16 PHP
centos7上编译安装php7以php-fpm方式连接apache
2018/11/08 PHP
PHP实现无限极分类的两种方式示例【递归和引用方式】
2019/03/25 PHP
bgsound 背景音乐 的一些常用方法及特殊用法小结
2010/05/11 Javascript
在父页面调用子页面的JS方法
2013/09/29 Javascript
使用JS或jQuery模拟鼠标点击a标签事件代码
2014/03/10 Javascript
常用的几段javascript代码分享
2014/03/25 Javascript
Bootstrap CSS布局之按钮
2016/12/17 Javascript
JS使用正则表达式找出最长连续子串长度
2017/10/26 Javascript
使用javascript做在线算法编程
2018/05/25 Javascript
vue input输入框关键字筛选检索列表数据展示
2020/10/26 Javascript
在layui tab控件中载入外部html页面的方法
2019/09/04 Javascript
VUE使用axios调用后台API接口的方法
2020/08/03 Javascript
javascript使用正则表达式实现注册登入校验
2020/09/23 Javascript
keep-alive保持组件状态的方法
2020/12/02 Javascript
[01:11:32]VG vs FNATIC 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
Python多线程编程(三):threading.Thread类的重要函数和方法
2015/04/05 Python
使用rst2pdf实现将sphinx生成PDF
2016/06/07 Python
Python的Tornado框架实现图片上传及图片大小修改功能
2016/06/30 Python
详解Python进程间通信之命名管道
2017/08/28 Python
Windows 7下Python Web环境搭建图文教程
2018/03/20 Python
centos6.8安装python3.7无法import _ssl的解决方法
2018/09/17 Python
python文本数据处理学习笔记详解
2019/06/17 Python
Django模板标签{% for %}循环,获取制定条数据实例
2020/05/14 Python
Python命名空间及作用域原理实例解析
2020/08/12 Python
css3 border旋转时的动画应用
2016/01/22 HTML / CSS
HTML5图片预览实例分享
2014/06/04 HTML / CSS
HTML5中图片之间的缝隙完美解决方法
2017/07/07 HTML / CSS
全球领先的中国制造商品在线批发平台:DHgate
2020/01/28 全球购物
在使用非全零作为空指针内部表达的机器上, NULL是如何定义
2014/11/09 面试题
试用期工作表现自我评价
2015/03/06 职场文书
土木工程生产实习心得体会
2016/01/22 职场文书