python实现KNN近邻算法


Posted in Python onDecember 30, 2020

示例:《电影类型分类》

获取数据来源

电影名称 打斗次数 接吻次数 电影类型
California Man 3 104 Romance
He's Not Really into Dudes 8 95 Romance
Beautiful Woman 1 81 Romance
Kevin Longblade 111 15 Action
Roob Slayer 3000 99 2 Action
Amped II 88 10 Action
Unknown 18 90 unknown

数据显示:肉眼判断电影类型unknown是什么

from matplotlib import pyplot as plt
​
# 用来正常显示中文标签
plt.rcParams["font.sans-serif"] = ["SimHei"]
# 电影名称
names = ["California Man", "He's Not Really into Dudes", "Beautiful Woman",
   "Kevin Longblade", "Robo Slayer 3000", "Amped II", "Unknown"]
# 类型标签
labels = ["Romance", "Romance", "Romance", "Action", "Action", "Action", "Unknown"]
colors = ["darkblue", "red", "green"]
colorDict = {label: color for (label, color) in zip(set(labels), colors)}
print(colorDict)
# 打斗次数,接吻次数
X = [3, 8, 1, 111, 99, 88, 18]
Y = [104, 95, 81, 15, 2, 10, 88]
​
plt.title("通过打斗次数和接吻次数判断电影类型", fontsize=18)
plt.xlabel("电影中打斗镜头出现的次数", fontsize=16)
plt.ylabel("电影中接吻镜头出现的次数", fontsize=16)
​
# 绘制数据
for i in range(len(X)):
 # 散点图绘制
 plt.scatter(X[i], Y[i], color=colorDict[labels[i]])
​
# 每个点增加描述信息
for i in range(0, 7):
 plt.text(X[i]+2, Y[i]-1, names[i], fontsize=14)
​
plt.show()

问题分析:根据已知信息分析电影类型unknown是什么

核心思想:

未标记样本的类别由距离其最近的K个邻居的类别决定

距离度量:

一般距离计算使用欧式距离(用勾股定理计算距离),也可以采用曼哈顿距离(水平上和垂直上的距离之和)、余弦值和相似度(这是距离的另一种表达方式)。相比于上述距离,马氏距离更为精确,因为它能考虑很多因素,比如单位,由于在求协方差矩阵逆矩阵的过程中,可能不存在,而且若碰见3维及3维以上,求解过程中极其复杂,故可不使用马氏距离

知识扩展

  • 马氏距离概念:表示数据的协方差距离
  • 方差:数据集中各个点到均值点的距离的平方的平均值
  • 标准差:方差的开方
  • 协方差cov(x, y):E表示均值,D表示方差,x,y表示不同的数据集,xy表示数据集元素对应乘积组成数据集

cov(x, y) = E(xy) - E(x)*E(y)

cov(x, x) = D(x)

cov(x1+x2, y) = cov(x1, y) + cov(x2, y)

cov(ax, by) = abcov(x, y)

  • 协方差矩阵:根据维度组成的矩阵,假设有三个维度,a,b,c

∑ij = [cov(a, a) cov(a, b) cov(a, c) cov(b, a) cov(b,b) cov(b, c) cov(c, a) cov(c, b) cov(c, c)]

算法实现:欧氏距离

编码实现

# 自定义实现 mytest1.py
import numpy as np
​
# 创建数据集
def createDataSet():
 features = np.array([[3, 104], [8, 95], [1, 81], [111, 15],
       [99, 2], [88, 10]])
 labels = ["Romance", "Romance", "Romance", "Action", "Action", "Action"]
 return features, labels
​
def knnClassify(testFeature, trainingSet, labels, k):
 """
 KNN算法实现,采用欧式距离
 :param testFeature: 测试数据集,ndarray类型,一维数组
 :param trainingSet: 训练数据集,ndarray类型,二维数组
 :param labels: 训练集对应标签,ndarray类型,一维数组
 :param k: k值,int类型
 :return: 预测结果,类型与标签中元素一致
 """
 dataSetsize = trainingSet.shape[0]
 """
 构建一个由dataSet[i] - testFeature的新的数据集diffMat
 diffMat中的每个元素都是dataSet中每个特征与testFeature的差值(欧式距离中差)
 """
 testFeatureArray = np.tile(testFeature, (dataSetsize, 1))
 diffMat = testFeatureArray - trainingSet
 # 对每个差值求平方
 sqDiffMat = diffMat ** 2
 # 计算dataSet中每个属性与testFeature的差的平方的和
 sqDistances = sqDiffMat.sum(axis=1)
 # 计算每个feature与testFeature之间的欧式距离
 distances = sqDistances ** 0.5
​
 """
 排序,按照从小到大的顺序记录distances中各个数据的位置
 如distance = [5, 9, 0, 2]
 则sortedStance = [2, 3, 0, 1]
 """
 sortedDistances = distances.argsort()
​
 # 选择距离最小的k个点
 classCount = {}
 for i in range(k):
  voteiLabel = labels[list(sortedDistances).index(i)]
  classCount[voteiLabel] = classCount.get(voteiLabel, 0) + 1
 # 对k个结果进行统计、排序,选取最终结果,将字典按照value值从大到小排序
 sortedclassCount = sorted(classCount.items(), key=lambda x: x[1], reverse=True)
 return sortedclassCount[0][0]
​
testFeature = np.array([100, 200])
features, labels = createDataSet()
res = knnClassify(testFeature, features, labels, 3)
print(res)
# 使用python包实现 mytest2.py
from sklearn.neighbors import KNeighborsClassifier
from .mytest1 import createDataSet
​
features, labels = createDataSet()
k = 5
clf = KNeighborsClassifier(k_neighbors=k)
clf.fit(features, labels)
​
# 样本值
my_sample = [[18, 90]]
res = clf.predict(my_sample)
print(res)

示例:《交友网站匹配效果预测》

数据来源:略

数据显示

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
​
# 数据加载
def loadDatingData(file):
 datingData = pd.read_table(file, header=None)
 datingData.columns = ["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek", "label"]
 datingTrainData = np.array(datingData[["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek"]])
 datingTrainLabel = np.array(datingData["label"])
 return datingData, datingTrainData, datingTrainLabel
​
# 3D图显示数据
def dataView3D(datingTrainData, datingTrainLabel):
 plt.figure(1, figsize=(8, 3))
 plt.subplot(111, projection="3d")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]), c="red")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]), c="green")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]), c="blue")
 plt.xlabel("飞行里程数", fontsize=16)
 plt.ylabel("视频游戏耗时百分比", fontsize=16)
 plt.clabel("冰淇凌消耗", fontsize=16)
 plt.show()
 
datingData, datingTrainData, datingTrainLabel = loadDatingData(FILEPATH1)
datingView3D(datingTrainData, datingTrainLabel)

问题分析:抽取数据集的前10%在数据集的后90%进行测试

编码实现

# 自定义方法实现
import pandas as pd
import numpy as np
​
# 数据加载
def loadDatingData(file):
 datingData = pd.read_table(file, header=None)
 datingData.columns = ["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek", "label"]
 datingTrainData = np.array(datingData[["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek"]])
 datingTrainLabel = np.array(datingData["label"])
 return datingData, datingTrainData, datingTrainLabel
​
# 数据归一化
def autoNorm(datingTrainData):
 # 获取数据集每一列的最值
 minValues, maxValues = datingTrainData.min(0), datingTrainData.max(0)
 diffValues = maxValues - minValues
 
 # 定义形状和datingTrainData相似的最小值矩阵和差值矩阵
 m = datingTrainData.shape(0)
 minValuesData = np.tile(minValues, (m, 1))
 diffValuesData = np.tile(diffValues, (m, 1))
 normValuesData = (datingTrainData-minValuesData)/diffValuesData
 return normValuesData
​
# 核心算法实现
def KNNClassifier(testData, trainData, trainLabel, k):
 m = trainData.shape(0)
 testDataArray = np.tile(testData, (m, 1))
 diffDataArray = (testDataArray - trainData) ** 2
 sumDataArray = diffDataArray.sum(axis=1) ** 0.5
 # 对结果进行排序
 sumDataSortedArray = sumDataArray.argsort()
 
 classCount = {}
 for i in range(k):
  labelName = trainLabel[list(sumDataSortedArray).index(i)]
  classCount[labelName] = classCount.get(labelName, 0)+1
 classCount = sorted(classCount.items(), key=lambda x: x[1], reversed=True)
 return classCount[0][0]
 
​
# 数据测试
def datingTest(file):
 datingData, datingTrainData, datingTrainLabel = loadDatingData(file)
 normValuesData = autoNorm(datingTrainData)
 
 
 errorCount = 0
 ratio = 0.10
 total = datingTrainData.shape(0)
 numberTest = int(total * ratio)
 for i in range(numberTest):
  res = KNNClassifier(normValuesData[i], normValuesData[numberTest:m], datingTrainLabel, 5)
  if res != datingTrainLabel[i]:
   errorCount += 1
 print("The total error rate is : {}\n".format(error/float(numberTest)))
​
if __name__ == "__main__":
 FILEPATH = "./datingTestSet1.txt"
 datingTest(FILEPATH)
# python 第三方包实现
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
​
if __name__ == "__main__":
 FILEPATH = "./datingTestSet1.txt"
 datingData, datingTrainData, datingTrainLabel = loadDatingData(FILEPATH)
 normValuesData = autoNorm(datingTrainData)
 errorCount = 0
 ratio = 0.10
 total = normValuesData.shape[0]
 numberTest = int(total * ratio)
 
 k = 5
 clf = KNeighborsClassifier(n_neighbors=k)
 clf.fit(normValuesData[numberTest:total], datingTrainLabel[numberTest:total])
 
 for i in range(numberTest):
  res = clf.predict(normValuesData[i].reshape(1, -1))
  if res != datingTrainLabel[i]:
   errorCount += 1
 print("The total error rate is : {}\n".format(errorCount/float(numberTest)))

以上就是python实现KNN近邻算法的详细内容,更多关于python实现KNN近邻算法的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python的Django框架使用入门指引
Apr 15 Python
在Linux中通过Python脚本访问mdb数据库的方法
May 06 Python
Python实现读取及写入csv文件的方法示例
Jan 12 Python
Python Cookie 读取和保存方法
Dec 28 Python
VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法详解
Jul 01 Python
对python中基于tcp协议的通信(数据传输)实例讲解
Jul 22 Python
深入解析神经网络从原理到实现
Jul 26 Python
Python while循环使用else语句代码实例
Feb 07 Python
python应用Axes3D绘图(批量梯度下降算法)
Mar 25 Python
keras的siamese(孪生网络)实现案例
Jun 12 Python
详解vscode实现远程linux服务器上Python开发
Nov 10 Python
python反爬虫方法的优缺点分析
Nov 25 Python
python 实现逻辑回归
Dec 30 #Python
Python 随机按键模拟2小时
Dec 30 #Python
Python的scikit-image模块实例讲解
Dec 30 #Python
用Python实现职工信息管理系统
Dec 30 #Python
python实现双人五子棋(终端版)
Dec 30 #Python
pandas 数据类型转换的实现
Dec 29 #Python
python中xlutils库用法浅析
Dec 29 #Python
You might like
php 表单数据的获取代码
2009/03/10 PHP
php+mysql事务rollback&commit示例
2010/02/08 PHP
json的键名为数字时的调用方式(示例代码)
2013/11/15 PHP
php输出金字塔的2种实现方法
2014/12/16 PHP
php微信公众号开发之现金红包
2018/04/16 PHP
php7连接MySQL实现简易查询程序的方法
2020/10/13 PHP
Jquery知识点二 jquery下对数组的操作
2011/01/15 Javascript
js弹出确认是否删除对话框
2014/03/27 Javascript
jQuery插件分享之分页插件jqPagination
2014/06/06 Javascript
javascript中HTMLDOM操作详解
2014/12/11 Javascript
Select2.js下拉框使用小结
2016/10/24 Javascript
Bootstrap源码解读表单(2)
2016/12/22 Javascript
微信小程序手势操作之单触摸点与多触摸点
2017/03/10 Javascript
js字符串与Unicode编码互相转换
2017/05/17 Javascript
JS库之Particles.js中文开发手册及参数详解
2017/09/13 Javascript
angularJS实现动态添加,删除div方法
2018/02/27 Javascript
全站最详细的Vuex教程
2018/04/13 Javascript
webpack的tree shaking的实现方法
2019/09/18 Javascript
[03:12]完美世界DOTA2联赛PWL DAY6集锦
2020/11/05 DOTA
深度剖析使用python抓取网页正文的源码
2014/06/11 Python
python爬虫常用的模块分析
2014/08/29 Python
在Linux下调试Python代码的各种方法
2015/04/17 Python
Python中条件判断语句的简单使用方法
2015/08/21 Python
python开发环境PyScripter中文乱码问题解决方案
2016/09/11 Python
python 读取竖线分隔符的文本方法
2018/12/20 Python
Python Numpy库安装与基本操作示例
2019/01/08 Python
Django实现学员管理系统
2019/02/26 Python
CSS3只让背景图片旋转180度的实现示例
2021/03/09 HTML / CSS
探索HTML5本地存储功能运用技巧
2016/03/02 HTML / CSS
Stührling手表官方网站:男女高品质时尚手表的领先零售商
2021/01/07 全球购物
生产主管岗位职责
2013/11/10 职场文书
旅游管理专业生自荐信范文
2014/01/02 职场文书
2015年毕业生自我鉴定模板
2014/09/19 职场文书
《比尾巴》教学反思
2016/02/24 职场文书
Python中的xlrd模块使用整理
2021/06/15 Python
使用refresh_token实现无感刷新页面
2022/04/26 Javascript