利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
go和python调用其它程序并得到程序输出
Feb 10 Python
Python运行报错UnicodeDecodeError的解决方法
Jun 07 Python
python基于ID3思想的决策树
Jan 03 Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 Python
python+opencv识别图片中的圆形
Mar 25 Python
Python lambda表达式用法实例分析
Dec 25 Python
python实现一个简单的udp通信的示例代码
Feb 01 Python
python里运用私有属性和方法总结
Jul 08 Python
python使用paramiko模块通过ssh2协议对交换机进行配置的方法
Jul 25 Python
详解python UDP 编程
Aug 24 Python
分位数回归模型quantile regeression应用详解及示例教程
Nov 02 Python
python中validators库的使用方法详解
Sep 23 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
Syphon 使用方法
2021/03/03 冲泡冲煮
图片存储与浏览一例(Linux+Apache+PHP+MySQL)
2006/10/09 PHP
php实现cookie加密的方法
2015/03/10 PHP
Zend Framework入门教程之Zend_View组件用法示例
2016/12/09 PHP
php正则提取html图片(img)src地址与任意属性的方法
2017/02/08 PHP
JAVASCRIPT 对象的创建与使用
2021/03/09 Javascript
alixixi runcode.asp的代码不错的应用
2007/08/08 Javascript
jquery checkbox全选、取消全选实现代码
2010/03/05 Javascript
JS打开新窗口的2种方式
2013/04/18 Javascript
JavaScrip实现PHP print_r的数功能(三种方法)
2013/11/12 Javascript
JavaScript中的ArrayBuffer详细介绍
2014/12/08 Javascript
浅谈javascript中onbeforeunload与onunload事件
2015/12/10 Javascript
jQuery实现响应鼠标事件的图片透明效果【附demo源码下载】
2016/06/16 Javascript
vue图片加载与显示默认图片实例代码
2017/03/16 Javascript
vue跨域解决方法
2017/10/15 Javascript
jQuery创建及操作xml格式数据示例
2018/05/26 jQuery
layui select 禁止点击的实现方法
2019/09/05 Javascript
es6函数之尾调用优化实例分析
2020/04/25 Javascript
浅要分析Python程序与C程序的结合使用
2015/04/07 Python
一个基于flask的web应用诞生(1)
2017/04/11 Python
从CentOS安装完成到生成词云python的实例
2017/12/01 Python
Python实现的网页截图功能【PyQt4与selenium组件】
2018/07/12 Python
树莓派安装OpenCV3完整过程的实现
2019/10/10 Python
Python使用sqlite3模块内置数据库
2020/05/07 Python
python中常用的数据结构介绍
2021/01/12 Python
CSS3制作彩色进度条样式的代码示例分享
2016/06/23 HTML / CSS
西铁城美国官方网站:Citizen Watch美国
2019/11/08 全球购物
升职自荐信
2013/11/28 职场文书
七夕情人节促销方案
2014/06/07 职场文书
校园广播稿精选
2014/10/01 职场文书
个人德育工作总结
2015/03/05 职场文书
离职信范文
2015/06/23 职场文书
Python爬虫数据的分类及json数据使用小结
2021/03/29 Python
Django 如何实现文件上传下载
2021/04/08 Python
vue+elementui 实现新增和修改共用一个弹框的完整代码
2021/06/08 Vue.js
微软发布Windows 11今年最大更新22H2(附 ISO 镜像官方下载)
2022/09/23 数码科技