利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Django框架中select_related和prefetch_related函数对数据库查询优化
Apr 01 Python
Python通过select实现异步IO的方法
Jun 04 Python
Python实现excel转sqlite的方法
Jul 17 Python
django站点管理详解
Dec 12 Python
python使用itchat库实现微信机器人(好友聊天、群聊天)
Jan 04 Python
python3解析库BeautifulSoup4的安装配置与基本用法
Jun 26 Python
在Django中URL正则表达式匹配的方法
Dec 20 Python
keras 特征图可视化实例(中间层)
Jan 24 Python
python GUI库图形界面开发之PyQt5下拉列表框控件QComboBox详细使用方法与实例
Feb 27 Python
Python3.8安装Pygame教程步骤详解
Aug 14 Python
Python安装第三方库攻略(pip和Anaconda)
Oct 15 Python
python之PySide2安装使用及QT Designer UI设计案例教程
Jul 26 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
PHP的面试题集
2006/11/19 PHP
浅析PHP文件下载原理
2014/12/25 PHP
php+mysqli实现批量替换数据库表前缀的方法
2014/12/29 PHP
PHP 命名空间和自动加载原理与用法实例分析
2020/04/29 PHP
javascript 利用Image对象实现的埋点(某处的点击数)统计
2012/12/28 Javascript
jQuery插件实现表格隔行换色且感应鼠标高亮行变色
2013/09/22 Javascript
jquery实现checkbox全选全不选的简单实例
2013/12/31 Javascript
jQuery简易图片放大特效示例代码
2014/06/09 Javascript
AngularJS 使用 UI Router 实现表单向导
2016/01/29 Javascript
Javascript+CSS3实现进度条效果
2016/10/28 Javascript
有趣的bootstrap走动进度条
2016/12/01 Javascript
微信小程序 登录的简单实现
2017/04/19 Javascript
解决JS外部文件中文注释出现乱码问题
2017/07/09 Javascript
react路由配置方式详解
2017/08/07 Javascript
jQuery实现菜单栏导航效果
2017/08/15 jQuery
今天,小程序正式支持 SVG
2019/04/20 Javascript
JavaScript RegExp 对象用法详解
2019/09/24 Javascript
nuxt.js 在middleware(中间件)中实现路由鉴权操作
2020/11/06 Javascript
使用Python读写及压缩和解压缩文件的示例
2016/07/08 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
2018/05/25 Python
python中scikit-learn机器代码实例
2018/08/05 Python
python3应用windows api对后台程序窗口及桌面截图并保存的方法
2019/08/27 Python
python读取word 中指定位置的表格及表格数据
2019/10/23 Python
Python: glob匹配文件的操作
2020/12/11 Python
利用css3画个同心圆示例代码
2017/07/03 HTML / CSS
基于canvas的骨骼动画的示例代码
2018/06/12 HTML / CSS
加拿大奢华时装品牌:Mackage
2018/01/10 全球购物
捷克鲜花配送:Florea.cz
2018/10/29 全球购物
六道php面试题附答案
2014/06/05 面试题
中科软测试工程师面试题
2012/06/16 面试题
中班中秋节活动反思
2014/02/18 职场文书
精彩的广告词
2014/03/19 职场文书
个人借条范本
2015/05/25 职场文书
欠条格式范本
2015/07/03 职场文书
小学六年级毕业感言
2015/07/30 职场文书
Redis可视化客户端小结
2021/06/10 Redis