用Python实现KNN分类算法


Posted in Python onDecember 22, 2017

本文实例为大家分享了Python KNN分类算法的具体代码,供大家参考,具体内容如下

KNN分类算法应该算得上是机器学习中最简单的分类算法了,所谓KNN即为K-NearestNeighbor(K个最邻近样本节点)。在进行分类之前KNN分类器会读取较多数量带有分类标签的样本数据作为分类的参照数据,当它对类别未知的样本进行分类时,会计算当前样本与所有参照样本的差异大小;该差异大小是通过数据点在样本特征的多维度空间中的距离来进行衡量的,也就是说,如果两个样本点在在其特征数据多维度空间中的距离越近,则这两个样本点之间的差异就越小,这两个样本点属于同一类别的可能性就越大。KNN分类算法利用这一基本的认知,通过计算待预测样本点与参照样本空间中所有的样本的距离,并找到K个距离该样本点最近的参照样本点,统计出这最邻近的K个样本点中占比数量最多的类别,并将该类别作为预测结果。

用Python实现KNN分类算法

KNN的模型十分简单,没有涉及到模型的训练,每一次预测都需要计算该点与所有已知点的距离,因此随着参照样本集的数量增加,KNN分类器的计算开销也会呈比例增加,并且KNN并不适合数量很少的样本集。并且KNN提出之后,后续很多人提出了很多改进的算法,分别从提高算法速率和提高算法准确率两个方向,但是都是基于“距离越近,相似的可能性越大”的原则。这里利用Python实现了KNN最原始版本的算法,数据集使用的是机器学习课程中使用得非常多的莺尾花数据集,同时我在原数据集的基础上向数据集中添加了少量的噪声数据,测试KNN算法的鲁棒性。

数据集用得是莺尾花数据集,下载地址。

用Python实现KNN分类算法

数据集包含90个数据(训练集),分为2类,每类45个数据,每个数据4个属性 

Sepal.Length(花萼长度),单位是cm;
Sepal.Width(花萼宽度),单位是cm;
Petal.Length(花瓣长度),单位是cm;
Petal.Width(花瓣宽度),单位是cm;

分类种类: Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾)
之前主打C++,近来才学的Python,今天想拿实现KNN来练练手,下面上代码:

#coding=utf-8
import math
#定义鸢尾花的数据类
class Iris:
 data=[]
 label=[]
 pass
#定义一个读取莺尾花数据集的函数
def load_dataset(filename="Iris_train.txt"):
 f=open(filename)
 line=f.readline().strip()
 propty=line.split(',')#属性名
 dataset=[]#保存每一个样本的数据信息
 label=[]#保存样本的标签
 while line:
 line=f.readline().strip()
 if(not line):
 break
 temp=line.split(',')
 content=[]
 for i in temp[0:-1]:
 content.append(float(i))
 dataset.append(content)
 label.append(temp[-1])
 total=Iris()
 total.data=dataset
 total.label=label
 return total#返回数据集
 
#定义一个Knn分类器类
class KnnClassifier:
 def __init__(self,k,type="Euler"):#初始化的时候定义正整数K和距离计算方式
 self.k=k
 self.type=type
 self.dataloaded=False
 def load_traindata(self,traindata):#加载数据集
 self.data=traindata.data
 self.label=traindata.label
 self.label_set=set(traindata.label)
 self.dataloaded=True#是否加载数据集的标记
 
 def Euler_dist(self,x,y):# 欧拉距离计算方法,x、y都是向量
 sum=0
 for i,j in zip(x,y):
 sum+=math.sqrt((i-j)**2)
 return sum
 def Manhattan_dist(self,x,y):#曼哈顿距离计算方法,x、y都是向量
 sum=0
 for i,j in zip(x,y):
 sum+=abs(i-j)
 return sum
 def predict(self,temp):#预测函数,读入一个预测样本的数据,temp是一个向量
 if(not self.dataloaded):#判断是否有训练数据
 print "No train_data load in"
 return
 distance_and_label=[]
 if(self.type=="Euler"):#判断距离计算方式,欧拉距离或者曼哈顿距离
 for i,j in zip(self.data,self.label):
 dist=self.Euler_dist(temp,i)
 distance_and_label.append([dist,j])
 else:
 if(self.type=="Manhattan"):
 for i,j in zip(self.data,self.label):
 dist=self.Manhattan_dist(temp,i)
 distance_and_label.append([dist,j])
 else:
 print "type choice error"
 #获取K个最邻近的样本的距离和类别标签
 neighborhood=sorted(distance_and_label,cmp=lambda x,y : cmp(x[0],y[0]))[0:self.k]
 neighborhood_class=[]
 for i in neighborhood:
 neighborhood_class.append(i[1])
 class_set=set(neighborhood_class)
 neighborhood_class_count=[]
 print "In k nearest neighborhoods:"
 #统计该K个最邻近点中各个类别的个数
 for i in class_set:
 a=neighborhood_class.count(i)
 neighborhood_class_count.append([i,a])
 print "class: ",i," count: ",a
 result=sorted(neighborhood_class_count,cmp=lambda x,y : cmp(x[1],y[1]))[-1][0]
 print "result: ",result
 return result#返回预测的类别
 
if __name__ == '__main__':
 traindata=load_dataset()#training data
 testdata=load_dataset("Iris_test.txt")#testing data
 #新建一个Knn分类器的K为20,默认为欧拉距离计算方式
 kc=KnnClassifier(20)
 kc.load_traindata(traindata)
 predict_result=[]
 #预测测试集testdata中所有待预测样本的结果
 for i,j in zip(testdata.data,testdata.label):
 predict_result.append([i,kc.predict(i),j])
 correct_count=0
 #将预测结果和正确结果进行比对,计算该次预测的准确率
 for i in predict_result:
 if(i[1]==i[2]):
 correct_count+=1
 ratio=float(correct_count)/len(predict_result)
 print "correct predicting ratio",ratio

测试集中11个待测样本点的分类结果:

In k nearest neighborhoods:
class: Iris-setosa count: 20
result: Iris-setosa
In k nearest neighborhoods:
class: Iris-setosa count: 20
result: Iris-setosa
In k nearest neighborhoods:
class: Iris-setosa count: 20
result: Iris-setosa
In k nearest neighborhoods:
class: Iris-setosa count: 20
result: Iris-setosa
In k nearest neighborhoods:
class: Iris-setosa count: 20
result: Iris-setosa
In k nearest neighborhoods:
class: Iris-versicolor count: 20
result: Iris-versicolor
In k nearest neighborhoods:
class: Iris-versicolor count: 20
result: Iris-versicolor
In k nearest neighborhoods:
class: Iris-versicolor count: 20
result: Iris-versicolor
In k nearest neighborhoods:
class: Iris-versicolor count: 20
result: Iris-versicolor
In k nearest neighborhoods:
class: Iris-versicolor count: 20
result: Iris-versicolor
In k nearest neighborhoods:
class: Iris-setosa count: 18
class: Iris-versicolor count: 2
result: Iris-setosa
correct predicting ratio 0.909090909091

KNN中对距离的计算有很多种方法,不同的方法适用于不同的数据集,该代码中只实现了欧拉距离和曼哈顿距离两种计算方式;测试集中的数据是从原数据集中抽离出来的,数据量不是很大,结果并不能很好地体现KNN的性能,所以程序运行结果仅供参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现归并排序算法
Jun 05 Python
Python中死锁的形成示例及死锁情况的防止
Jun 14 Python
Android基于TCP和URL协议的网络编程示例【附demo源码下载】
Jan 23 Python
python安装pywin32clipboard的操作方法
Jan 24 Python
pygame实现烟雨蒙蒙下彩虹雨
Nov 11 Python
详解python opencv、scikit-image和PIL图像处理库比较
Dec 26 Python
python字符串替换re.sub()实例解析
Feb 09 Python
python GUI库图形界面开发之PyQt5菜单栏控件QMenuBar的详细使用方法与实例
Feb 28 Python
Python 面向对象静态方法、类方法、属性方法知识点小结
Mar 09 Python
解决Python在导入文件时的FileNotFoundError问题
Apr 10 Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 Python
python程序需要编译吗
Jun 19 Python
Python数据拟合与广义线性回归算法学习
Dec 22 #Python
python 动态加载的实现方法
Dec 22 #Python
Python决策树分类算法学习
Dec 22 #Python
Python之Scrapy爬虫框架安装及简单使用详解
Dec 22 #Python
Python2.7下安装Scrapy框架步骤教程
Dec 22 #Python
Python机器学习之决策树算法
Dec 22 #Python
python+selenium实现登录账户后自动点击的示例
Dec 22 #Python
You might like
第十五节--Zend引擎的发展
2006/11/16 PHP
PHP迭代器实现斐波纳契数列的函数
2013/11/12 PHP
排序算法之PHP版快速排序、冒泡排序
2014/04/09 PHP
JS图片浏览组件PhotoLook的公开属性方法介绍和进阶实例代码
2010/11/09 Javascript
jQuery 获取/设置/删除DOM元素的属性以a元素为例
2014/05/23 Javascript
浅谈javascript中字符串String与数组Array
2014/12/31 Javascript
jQuery实现为控件添加水印文字效果(附源码)
2015/12/02 Javascript
javascript自动恢复文本框点击清除后的默认文本
2016/01/12 Javascript
JavaScript弹出对话框的三种方式
2016/03/23 Javascript
Angular.Js中过滤器filter与自定义过滤器filter实例详解
2017/05/08 Javascript
Nodejs 复制文件/文件夹的方法
2017/08/24 NodeJs
Vue CLI3.0中使用jQuery和Bootstrap的方法
2019/02/28 jQuery
es6数值的扩展方法
2019/03/11 Javascript
微信小程序导入Vant报错VM292:1 thirdScriptError的解决方法
2019/08/01 Javascript
JavaScript实现页面高亮操作提示和蒙板
2021/01/04 Javascript
python实现随机密码字典生成器示例
2014/04/09 Python
盘点提高 Python 代码效率的方法
2014/07/03 Python
跟老齐学Python之玩转字符串(1)
2014/09/14 Python
实例解析Python中的__new__特殊方法
2016/06/02 Python
python3 实现的人人影视网站自动签到
2016/06/19 Python
Python tkinter模块中类继承的三种方式分析
2017/08/08 Python
Python Unittest自动化单元测试框架详解
2018/04/04 Python
pandas修改DataFrame列名的方法
2018/04/08 Python
python爬取网页内容转换为PDF文件
2020/07/28 Python
python如何生成网页验证码
2018/07/28 Python
python输入整条数据分割存入数组的方法
2018/11/13 Python
python实现从本地摄像头和网络摄像头截取图片功能
2019/07/11 Python
在C语言中实现抽象数据类型什么方法最好
2014/06/26 面试题
用Java语言将一个键盘输入的数字转化成中文输出
2013/01/25 面试题
文明礼仪标语
2014/06/13 职场文书
会计岗位说明书
2014/07/29 职场文书
岗位职责范本大全
2015/02/26 职场文书
小学生安全保证书
2015/05/09 职场文书
基于CSS3画一个iPhone
2021/04/21 HTML / CSS
HTML+CSS+JS实现图片的瀑布流布局的示例代码
2021/04/22 HTML / CSS
Python 多线程处理任务实例
2021/11/07 Python