Python实现的KMeans聚类算法实例分析


Posted in Python onDecember 29, 2018

本文实例讲述了Python实现的KMeans聚类算法。分享给大家供大家参考,具体如下:

菜鸟一枚,编程初学者,最近想使用Python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。

关于KMeans算法本身就不做介绍了,下面记录一下自己遇到的问题。

一 、关于初始聚类中心的选取

初始聚类中心的选择一般有:

(1)随机选取

(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。

(3)使用层次聚类等算法更新出初始聚类中心

我一开始是使用numpy随机产生k个聚类中心

Center = np.random.randn(k,n)

但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事

所以我分别尝试了:

(1)选择数据集的前K个样本做初始中心点

(2)选择随机K个样本点作为初始聚类中心

发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前K个样本点做数据集时,迭代次数是固定的,选择随机K个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。

def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(5)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center

二 、关于类间距离的选取

为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。

def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array

三 、关于终止聚类条件的选取

关于聚类的终止条件有很多选择方法:

(1)迭代一定次数

(2)聚类中心的更新小于某个给定的阈值

(3)类中的样本不再变化

我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。

def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    print(err)
    t+=1
    plt.figure(1)
    p=plt.subplot(3, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter

err是本次聚类中心点和上次聚类中心点之间的欧氏距离。

threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。

为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。

下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。

Python实现的KMeans聚类算法实例分析

我测试了几个数据集,聚类的精度还是可以的。

使用iris数据集分析的结果为:

err of Iteration 1 is 3.11443180281
err of Iteration 2 is 1.27568813621
err of Iteration 3 is 0.198909381512
err of Iteration 4 is 0.0
Final cluster center is  [[ 6.85        3.07368421  5.74210526  2.07105263]
 [ 5.9016129   2.7483871   4.39354839  1.43387097]
 [ 5.006       3.428       1.462       0.246     ]]

最后附上全部代码,错误之处还请多多批评,谢谢。

#encoding:utf-8
"""
  Author:   njulpy
  Version:   1.0
  Data:   2018/04/11
  Project: Using Python to Implement KMeans Clustering Algorithm
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(15)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center
def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array
def GetNewCenter(x_train,k,n, Dis_array):
  cen = []
  axisx ,axisy,axisz= [],[],[]
  cls = np.argmin(Dis_array, axis=0)
  for i in range(k):
    train_i=x_train.loc[cls == i]
    xx,yy,zz = list(train_i.iloc[:,1]),list(train_i.iloc[:,2]),list(train_i.iloc[:,3])
    axisx.append(xx)
    axisy.append(yy)
    axisz.append(zz)
    meanC = np.mean(train_i,axis=0)
    cen.append(meanC)
  newcent = np.array(cen).reshape(k,n)
  NewCent=np.nan_to_num(newcent)
  return NewCent,axisx,axisy,axisz
def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    t+=1
    print('err of Iteration '+str(t),'is',err)
    plt.figure(1)
    p=plt.subplot(2, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter
if __name__=="__main__":
  #x=pd.read_csv("8.Advertising.csv")  # 两组测试数据
  #x=pd.read_table("14.bipartition.txt")
  x=pd.read_csv("iris.csv")
  x_train=x.iloc[:,1:5]
  m,n = np.shape(x_train)
  k = 3
  threshold = 0.1
  km,ax,ay,az,ddd = KMcluster(x_train, k, n, m, threshold)
  print('Final cluster center is ', km)
  #2-Dplot
  plt.figure(2)
  plt.scatter(km[0,1],km[0,2],c = 'r',s = 550,marker='x')
  plt.scatter(km[1,1],km[1,2],c = 'g',s = 550,marker='x')
  plt.scatter(km[2,1],km[2,2],c = 'b',s = 550,marker='x')
  p1, p2, p3 = plt.scatter(axis_x[0], axis_y[0], c='r'), plt.scatter(axis_x[1], axis_y[1], c='g'), plt.scatter(axis_x[2], axis_y[2], c='b')
  plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
  plt.title('2-D scatter')
  plt.show()
  #3-Dplot
  plt.figure(3)
  TreeD = plt.subplot(111, projection='3d')
  TreeD.scatter(ax[0],ay[0],az[0],c='r')
  TreeD.scatter(ax[1],ay[1],az[1],c='g')
  TreeD.scatter(ax[2],ay[2],az[2],c='b')
  TreeD.set_zlabel('Z') # 坐标轴
  TreeD.set_ylabel('Y')
  TreeD.set_xlabel('X')
  TreeD.set_title('3-D scatter')
  plt.show()

Python实现的KMeans聚类算法实例分析

Python实现的KMeans聚类算法实例分析

附:上述示例中的iris.csv文件点击此处本站下载

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python3中的json模块使用详解
May 05 Python
详解Python3定时器任务代码
Sep 23 Python
python3 tcp的粘包现象和解决办法解析
Dec 09 Python
TensorFlow加载模型时出错的解决方式
Feb 06 Python
Python多进程编程multiprocessing代码实例
Mar 12 Python
如何配置关联Python 解释器 Anaconda的教程(图解)
Apr 30 Python
.img/.hdr格式转.nii格式的操作
Jul 01 Python
python Protobuf定义消息类型知识点讲解
Mar 02 Python
2021年最新用于图像处理的Python库总结
Jun 15 Python
Django实现drf搜索过滤和排序过滤
Jun 21 Python
5道关于python基础 while循环练习题
Nov 27 Python
python中pd.cut()与pd.qcut()的对比及示例
Jun 16 Python
Python使用pyshp库读取shapefile信息的方法
Dec 29 #Python
Python实现的线性回归算法示例【附csv文件下载】
Dec 29 #Python
Python 确定多项式拟合/回归的阶数实例
Dec 29 #Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
Dec 29 #Python
Python实现高斯函数的三维显示方法
Dec 29 #Python
Python3 SSH远程连接服务器的方法示例
Dec 29 #Python
使用python绘制3维正态分布图的方法
Dec 29 #Python
You might like
复杂检索数据并分页显示的处理方法
2006/10/09 PHP
ThinkPHP模型详解
2015/07/27 PHP
如何使Chrome控制台支持多行js模式——意外发现
2013/06/13 Javascript
js仿百度贴吧验证码特效实例代码
2014/01/16 Javascript
JS下载文件|无刷新下载文件示例代码
2014/04/17 Javascript
javascript实现的字符串与十六进制表示字符串相互转换方法
2015/07/17 Javascript
模仿password输入框的实现代码
2016/06/07 Javascript
JavaScript中有关一个数组中最大值和最小值及它们的下表的输出的解决办法
2016/07/01 Javascript
js实现带缓动动画的导航栏效果
2017/01/16 Javascript
node.js平台下的mysql数据库配置及连接
2017/03/31 Javascript
浅谈angular4 ng-content 中隐藏的内容
2017/08/18 Javascript
vue引入jq插件的实例讲解
2017/09/12 Javascript
vue.js如何将echarts封装为组件一键使用详解
2017/10/10 Javascript
Jquery $.map使用方法实例详解
2020/09/01 jQuery
详解datagrid使用方法(重要)
2020/11/06 Javascript
对python3中, print横向输出的方法详解
2019/01/28 Python
对django 2.x版本中models.ForeignKey()外键说明介绍
2020/03/30 Python
解决IDEA 的 plugins 搜不到任何的插件问题
2020/05/04 Python
在django中form的label和verbose name的区别说明
2020/05/20 Python
TobyDeals美国:在电子产品上获得最好的优惠和折扣
2019/08/11 全球购物
星空联盟C# .net笔试题
2014/12/05 面试题
应届毕业生应聘自荐信范文
2014/02/26 职场文书
银行服务感言
2014/03/01 职场文书
个人委托书格式
2014/04/04 职场文书
旅游专业毕业生自荐书
2014/06/30 职场文书
个人三严三实对照检查材料思想汇报
2014/09/22 职场文书
教师三严三实对照检查材料
2014/09/25 职场文书
计算机实训报告总结
2014/11/05 职场文书
2014年幼儿园德育工作总结
2014/12/17 职场文书
2014年高中教师工作总结
2014/12/19 职场文书
2014年乡镇纪委工作总结
2014/12/19 职场文书
小学音乐教师个人工作总结
2015/02/05 职场文书
给男朋友的道歉短信
2015/05/12 职场文书
客户答谢会致辞
2015/07/30 职场文书
省级三好学生主要事迹材料
2015/11/03 职场文书
课改心得体会范文
2016/01/25 职场文书