Python实现的KMeans聚类算法实例分析


Posted in Python onDecember 29, 2018

本文实例讲述了Python实现的KMeans聚类算法。分享给大家供大家参考,具体如下:

菜鸟一枚,编程初学者,最近想使用Python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。

关于KMeans算法本身就不做介绍了,下面记录一下自己遇到的问题。

一 、关于初始聚类中心的选取

初始聚类中心的选择一般有:

(1)随机选取

(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。

(3)使用层次聚类等算法更新出初始聚类中心

我一开始是使用numpy随机产生k个聚类中心

Center = np.random.randn(k,n)

但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事

所以我分别尝试了:

(1)选择数据集的前K个样本做初始中心点

(2)选择随机K个样本点作为初始聚类中心

发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前K个样本点做数据集时,迭代次数是固定的,选择随机K个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。

def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(5)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center

二 、关于类间距离的选取

为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。

def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array

三 、关于终止聚类条件的选取

关于聚类的终止条件有很多选择方法:

(1)迭代一定次数

(2)聚类中心的更新小于某个给定的阈值

(3)类中的样本不再变化

我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。

def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    print(err)
    t+=1
    plt.figure(1)
    p=plt.subplot(3, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter

err是本次聚类中心点和上次聚类中心点之间的欧氏距离。

threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。

为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。

下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。

Python实现的KMeans聚类算法实例分析

我测试了几个数据集,聚类的精度还是可以的。

使用iris数据集分析的结果为:

err of Iteration 1 is 3.11443180281
err of Iteration 2 is 1.27568813621
err of Iteration 3 is 0.198909381512
err of Iteration 4 is 0.0
Final cluster center is  [[ 6.85        3.07368421  5.74210526  2.07105263]
 [ 5.9016129   2.7483871   4.39354839  1.43387097]
 [ 5.006       3.428       1.462       0.246     ]]

最后附上全部代码,错误之处还请多多批评,谢谢。

#encoding:utf-8
"""
  Author:   njulpy
  Version:   1.0
  Data:   2018/04/11
  Project: Using Python to Implement KMeans Clustering Algorithm
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(15)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center
def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array
def GetNewCenter(x_train,k,n, Dis_array):
  cen = []
  axisx ,axisy,axisz= [],[],[]
  cls = np.argmin(Dis_array, axis=0)
  for i in range(k):
    train_i=x_train.loc[cls == i]
    xx,yy,zz = list(train_i.iloc[:,1]),list(train_i.iloc[:,2]),list(train_i.iloc[:,3])
    axisx.append(xx)
    axisy.append(yy)
    axisz.append(zz)
    meanC = np.mean(train_i,axis=0)
    cen.append(meanC)
  newcent = np.array(cen).reshape(k,n)
  NewCent=np.nan_to_num(newcent)
  return NewCent,axisx,axisy,axisz
def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    t+=1
    print('err of Iteration '+str(t),'is',err)
    plt.figure(1)
    p=plt.subplot(2, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter
if __name__=="__main__":
  #x=pd.read_csv("8.Advertising.csv")  # 两组测试数据
  #x=pd.read_table("14.bipartition.txt")
  x=pd.read_csv("iris.csv")
  x_train=x.iloc[:,1:5]
  m,n = np.shape(x_train)
  k = 3
  threshold = 0.1
  km,ax,ay,az,ddd = KMcluster(x_train, k, n, m, threshold)
  print('Final cluster center is ', km)
  #2-Dplot
  plt.figure(2)
  plt.scatter(km[0,1],km[0,2],c = 'r',s = 550,marker='x')
  plt.scatter(km[1,1],km[1,2],c = 'g',s = 550,marker='x')
  plt.scatter(km[2,1],km[2,2],c = 'b',s = 550,marker='x')
  p1, p2, p3 = plt.scatter(axis_x[0], axis_y[0], c='r'), plt.scatter(axis_x[1], axis_y[1], c='g'), plt.scatter(axis_x[2], axis_y[2], c='b')
  plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
  plt.title('2-D scatter')
  plt.show()
  #3-Dplot
  plt.figure(3)
  TreeD = plt.subplot(111, projection='3d')
  TreeD.scatter(ax[0],ay[0],az[0],c='r')
  TreeD.scatter(ax[1],ay[1],az[1],c='g')
  TreeD.scatter(ax[2],ay[2],az[2],c='b')
  TreeD.set_zlabel('Z') # 坐标轴
  TreeD.set_ylabel('Y')
  TreeD.set_xlabel('X')
  TreeD.set_title('3-D scatter')
  plt.show()

Python实现的KMeans聚类算法实例分析

Python实现的KMeans聚类算法实例分析

附:上述示例中的iris.csv文件点击此处本站下载

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python编写Logistic逻辑回归
Dec 30 Python
Python查找文件中包含中文的行方法
Dec 19 Python
33个Python爬虫项目实战(推荐)
Jul 08 Python
python实现各种插值法(数值分析)
Jul 30 Python
django echarts饼图数据动态加载的实例
Aug 12 Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 Python
python网络爬虫 CrawlSpider使用详解
Sep 27 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
numpy 矩阵形状调整:拉伸、变成一位数组的实例
Jun 18 Python
Python pip install之SSL异常处理操作
Sep 03 Python
selenium自动化测试入门实战
Dec 21 Python
opencv实现图像平移效果
Mar 24 Python
Python使用pyshp库读取shapefile信息的方法
Dec 29 #Python
Python实现的线性回归算法示例【附csv文件下载】
Dec 29 #Python
Python 确定多项式拟合/回归的阶数实例
Dec 29 #Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
Dec 29 #Python
Python实现高斯函数的三维显示方法
Dec 29 #Python
Python3 SSH远程连接服务器的方法示例
Dec 29 #Python
使用python绘制3维正态分布图的方法
Dec 29 #Python
You might like
php.ini修改php上传文件大小限制的方法详解
2013/06/17 PHP
PHP中exec与system用法区别分析
2014/09/22 PHP
YII2.0之Activeform表单组件用法实例
2016/01/09 PHP
浅谈PHP的排列组合(如输入a,b,c 输出他们的全部组合)
2017/03/14 PHP
Phpstorm+Xdebug断点调试PHP的方法
2018/05/14 PHP
thinkphp框架无限级栏目的排序功能实现方法示例
2020/03/29 PHP
PHP接口类(interface)的定义、特点和应用示例
2020/05/18 PHP
神奇的代码 通杀各种网站-可随意修改复制页面内容
2008/07/17 Javascript
javascript 复杂的嵌套环境中输出单引号和双引号
2009/05/26 Javascript
javascript XMLHttpRequest对象全面剖析
2010/04/24 Javascript
js函数调用常用方法详解
2012/12/03 Javascript
表单序列化与jq中的serialize使用示例
2014/02/21 Javascript
php的文件上传入门教程(实例讲解)
2014/04/10 Javascript
jQuery实现标题有打字效果的焦点图代码
2015/11/16 Javascript
D3.js实现饼状图的方法详解
2016/09/21 Javascript
微信小程序 富文本转文本实例详解
2016/10/24 Javascript
vue+vuex+axios+echarts画一个动态更新的中国地图的方法
2017/12/19 Javascript
webpack构建的详细流程探底
2018/01/08 Javascript
微信小程序canvas拖拽、截图组件功能
2018/09/04 Javascript
Vue2.X和Vue3.0数据响应原理变化的区别
2019/11/07 Javascript
[01:02:48]2018DOTA2亚洲邀请赛 4.1 小组赛 A组 LGD vs OG
2018/04/02 DOTA
Python对列表中的各项进行关联详解
2017/08/15 Python
itchat和matplotlib的结合使用爬取微信信息的实例
2017/08/25 Python
Python爬虫框架Scrapy基本用法入门教程
2018/07/26 Python
Python中Proxypool库的安装与配置
2018/10/19 Python
python3实现网页版raspberry pi(树莓派)小车控制
2020/02/12 Python
python实现飞机大战项目
2020/03/11 Python
Pycharm快捷键配置详细整理
2020/10/13 Python
伦敦所有西区剧院演出官方票务代理:Theatre Tickets Direct
2017/05/26 全球购物
使用Vue.js和MJML创建响应式电子邮件
2021/03/23 Vue.js
餐厅周年庆活动方案
2014/08/25 职场文书
个人汇报材料范文
2014/12/30 职场文书
社区文明创建工作总结2015
2015/04/21 职场文书
聘任书范文大全
2015/09/21 职场文书
关于感恩老师的古诗句
2019/08/20 职场文书
Nginx 502 bad gateway错误解决的九种方案及原因
2022/08/14 Servers