利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中分数的相关使用教程
Mar 30 Python
Python设置Socket代理及实现远程摄像头控制的例子
Nov 13 Python
python+selenium实现京东自动登录及秒杀功能
Nov 18 Python
Python基于递归实现电话号码映射功能示例
Apr 13 Python
Python带动态参数功能的sqlite工具类
May 26 Python
对python 匹配字符串开头和结尾的方法详解
Oct 27 Python
图文详解python安装Scrapy框架步骤
May 20 Python
python GUI库图形界面开发之PyQt5工具栏控件QToolBar的详细使用方法与实例
Feb 28 Python
基于python 等频分箱qcut问题的解决
Mar 03 Python
Python图像阈值化处理及算法比对实例解析
Jun 19 Python
Python reques接口测试框架实现代码
Jul 28 Python
Django分页器的用法你都了解吗
May 26 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
Win2000+Apache+MySql+PHP4+PERL安装使用小结
2006/10/09 PHP
PHP 图片上传代码
2011/09/13 PHP
php实例分享之html转为rtf格式
2014/06/02 PHP
php静态文件生成类实例分析
2015/01/03 PHP
php输出xml属性的方法
2015/03/19 PHP
php自动给网址加上链接的方法
2015/06/02 PHP
JavaScript作用域链示例分享
2014/05/27 Javascript
JS截取与分割字符串常用技巧总结
2015/11/10 Javascript
WebView启动支付宝客户端支付失败的问题小结
2017/01/11 Javascript
Bootstrap 表单验证formValidation 实现表单动态验证功能
2017/05/17 Javascript
JS写谷歌浏览器chrome的外挂实例
2018/01/11 Javascript
vue cli2.0单页面title修改方法
2018/06/07 Javascript
使用ECharts实现状态区间图
2018/10/25 Javascript
JavaScript怎样在删除前添加确认弹出框?
2019/05/27 Javascript
Python3实现的腾讯微博自动发帖小工具
2013/11/11 Python
介绍Python中内置的itertools模块
2015/04/29 Python
简单了解python元组tuple相关原理
2019/12/02 Python
pandas中read_csv、rolling、expanding用法详解
2020/04/21 Python
Python调用C语言程序方法解析
2020/07/07 Python
浅析Python 序列化与反序列化
2020/08/05 Python
使用python爬取抖音app视频的实例代码
2020/12/01 Python
Canvas中设置width与height的问题浅析
2018/11/01 HTML / CSS
全球异乡人的跨境社交电商平台:Kouhigh口嗨网
2020/07/24 全球购物
某公司C#程序员面试题笔试题
2014/05/26 面试题
计算机应用毕业生自荐信
2013/10/23 职场文书
企业晚会策划方案
2014/05/29 职场文书
七一建党节演讲稿
2014/09/11 职场文书
庆七一宣传标语
2014/10/08 职场文书
六一晚会主持词开场白
2015/05/28 职场文书
实践论读书笔记
2015/06/29 职场文书
办公用品管理制度
2015/08/04 职场文书
检讨书格式
2019/04/25 职场文书
年终工作总结范文
2019/06/20 职场文书
golang fmt格式“占位符”的实例用法详解
2021/07/04 Golang
springboot+zookeeper实现分布式锁
2022/03/21 Java/Android
tomcat默认最大连接数及相关调整方法
2022/05/06 Servers