Python实现的KMeans聚类算法实例分析


Posted in Python onDecember 29, 2018

本文实例讲述了Python实现的KMeans聚类算法。分享给大家供大家参考,具体如下:

菜鸟一枚,编程初学者,最近想使用Python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。

关于KMeans算法本身就不做介绍了,下面记录一下自己遇到的问题。

一 、关于初始聚类中心的选取

初始聚类中心的选择一般有:

(1)随机选取

(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。

(3)使用层次聚类等算法更新出初始聚类中心

我一开始是使用numpy随机产生k个聚类中心

Center = np.random.randn(k,n)

但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事

所以我分别尝试了:

(1)选择数据集的前K个样本做初始中心点

(2)选择随机K个样本点作为初始聚类中心

发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前K个样本点做数据集时,迭代次数是固定的,选择随机K个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。

def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(5)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center

二 、关于类间距离的选取

为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。

def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array

三 、关于终止聚类条件的选取

关于聚类的终止条件有很多选择方法:

(1)迭代一定次数

(2)聚类中心的更新小于某个给定的阈值

(3)类中的样本不再变化

我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。

def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    print(err)
    t+=1
    plt.figure(1)
    p=plt.subplot(3, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter

err是本次聚类中心点和上次聚类中心点之间的欧氏距离。

threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。

为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。

下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。

Python实现的KMeans聚类算法实例分析

我测试了几个数据集,聚类的精度还是可以的。

使用iris数据集分析的结果为:

err of Iteration 1 is 3.11443180281
err of Iteration 2 is 1.27568813621
err of Iteration 3 is 0.198909381512
err of Iteration 4 is 0.0
Final cluster center is  [[ 6.85        3.07368421  5.74210526  2.07105263]
 [ 5.9016129   2.7483871   4.39354839  1.43387097]
 [ 5.006       3.428       1.462       0.246     ]]

最后附上全部代码,错误之处还请多多批评,谢谢。

#encoding:utf-8
"""
  Author:   njulpy
  Version:   1.0
  Data:   2018/04/11
  Project: Using Python to Implement KMeans Clustering Algorithm
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
def InitCenter(k,m,x_train):
  #Center = np.random.randn(k,n)
  #Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  Center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(15)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    Center[i] = np.array(x_train.iloc[x])
  return Center
def GetDistense(x_train, k, m, Center):
  Distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.T - Center[j]
      Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
      Distence.append(Dist)
  Dis_array = np.array(Distence).reshape(k,m)
  return Dis_array
def GetNewCenter(x_train,k,n, Dis_array):
  cen = []
  axisx ,axisy,axisz= [],[],[]
  cls = np.argmin(Dis_array, axis=0)
  for i in range(k):
    train_i=x_train.loc[cls == i]
    xx,yy,zz = list(train_i.iloc[:,1]),list(train_i.iloc[:,2]),list(train_i.iloc[:,3])
    axisx.append(xx)
    axisy.append(yy)
    axisz.append(zz)
    meanC = np.mean(train_i,axis=0)
    cen.append(meanC)
  newcent = np.array(cen).reshape(k,n)
  NewCent=np.nan_to_num(newcent)
  return NewCent,axisx,axisy,axisz
def KMcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = InitCenter(k,m,x_train)
  initcenter = center
  centerChanged = True
  t=0
  while centerChanged:
    Dis_array = GetDistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    t+=1
    print('err of Iteration '+str(t),'is',err)
    plt.figure(1)
    p=plt.subplot(2, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('Iteration'+ str(t))
    if err < threshold:
      centerChanged = False
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter
if __name__=="__main__":
  #x=pd.read_csv("8.Advertising.csv")  # 两组测试数据
  #x=pd.read_table("14.bipartition.txt")
  x=pd.read_csv("iris.csv")
  x_train=x.iloc[:,1:5]
  m,n = np.shape(x_train)
  k = 3
  threshold = 0.1
  km,ax,ay,az,ddd = KMcluster(x_train, k, n, m, threshold)
  print('Final cluster center is ', km)
  #2-Dplot
  plt.figure(2)
  plt.scatter(km[0,1],km[0,2],c = 'r',s = 550,marker='x')
  plt.scatter(km[1,1],km[1,2],c = 'g',s = 550,marker='x')
  plt.scatter(km[2,1],km[2,2],c = 'b',s = 550,marker='x')
  p1, p2, p3 = plt.scatter(axis_x[0], axis_y[0], c='r'), plt.scatter(axis_x[1], axis_y[1], c='g'), plt.scatter(axis_x[2], axis_y[2], c='b')
  plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
  plt.title('2-D scatter')
  plt.show()
  #3-Dplot
  plt.figure(3)
  TreeD = plt.subplot(111, projection='3d')
  TreeD.scatter(ax[0],ay[0],az[0],c='r')
  TreeD.scatter(ax[1],ay[1],az[1],c='g')
  TreeD.scatter(ax[2],ay[2],az[2],c='b')
  TreeD.set_zlabel('Z') # 坐标轴
  TreeD.set_ylabel('Y')
  TreeD.set_xlabel('X')
  TreeD.set_title('3-D scatter')
  plt.show()

Python实现的KMeans聚类算法实例分析

Python实现的KMeans聚类算法实例分析

附:上述示例中的iris.csv文件点击此处本站下载

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python中的左斜杠、右斜杠(正斜杠和反斜杠)
Aug 30 Python
浅谈对yield的初步理解
May 29 Python
matplotlib在python上绘制3D散点图实例详解
Dec 09 Python
如何用Python合并lmdb文件
Jul 02 Python
python实现反转部分单向链表
Sep 27 Python
详解python解压压缩包的五种方法
Jul 05 Python
Spring实战之使用util:命名空间简化配置操作示例
Dec 09 Python
Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
Feb 11 Python
python_mask_array的用法
Feb 18 Python
Python hashlib和hmac模块使用方法解析
Dec 08 Python
使用pandas读取表格数据并进行单行数据拼接的详细教程
Mar 03 Python
python plt.plot bar 如何设置绘图尺寸大小
Jun 01 Python
Python使用pyshp库读取shapefile信息的方法
Dec 29 #Python
Python实现的线性回归算法示例【附csv文件下载】
Dec 29 #Python
Python 确定多项式拟合/回归的阶数实例
Dec 29 #Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
Dec 29 #Python
Python实现高斯函数的三维显示方法
Dec 29 #Python
Python3 SSH远程连接服务器的方法示例
Dec 29 #Python
使用python绘制3维正态分布图的方法
Dec 29 #Python
You might like
phpmyadmin显示utf8_general_ci中文乱码的问题终级篇
2013/04/08 PHP
php实现批量压缩图片文件大小的脚本
2014/07/04 PHP
php从数组中随机选择若干不重复元素的方法
2015/03/14 PHP
php验证码生成代码
2015/11/11 PHP
PHP设计模式之工厂模式与单例模式
2016/09/28 PHP
PHP实现中国公民身份证号码有效性验证示例代码
2017/05/03 PHP
PHP使用Redis实现防止大并发下二次写入的方法
2017/10/09 PHP
thinkPHP框架乐观锁和悲观锁实例分析
2019/10/30 PHP
浅谈angularjs依赖服务注入写法的注意点
2017/04/24 Javascript
用vue和node写的简易购物车实现
2017/04/25 Javascript
使用Electron构建React+Webpack桌面应用的方法
2017/12/15 Javascript
Phaser.js实现简单的跑酷游戏附源码下载
2018/10/26 Javascript
javascript中函数的写法实例代码详解
2018/10/28 Javascript
谈谈JavaScript中super(props)的重要性
2019/02/12 Javascript
[09:47]2018DOTA2亚洲邀请赛4.5SOLO赛 No[o]ne vs Sumail
2018/04/06 DOTA
删除目录下相同文件的python代码(逐级优化)
2012/05/25 Python
Python记录详细调用堆栈日志的方法
2015/05/05 Python
Python的Flask框架的简介和安装方法
2015/11/13 Python
详解Python nose单元测试框架的安装与使用
2017/12/20 Python
selenium+python实现自动登录脚本
2018/04/22 Python
Python向Excel中插入图片的简单实现方法
2018/04/24 Python
django商品分类及商品数据建模实例详解
2020/01/03 Python
浅谈django框架集成swagger以及自定义参数问题
2020/07/07 Python
CSS3 二级导航菜单的制作的示例
2018/04/02 HTML / CSS
Delphi工程师笔试题
2013/09/21 面试题
销售代表求职自荐信
2013/10/01 职场文书
自我鉴定范文
2013/11/10 职场文书
十八届三中全会学习方案
2014/02/16 职场文书
生产部厂长职位说明书
2014/03/03 职场文书
2014领导干部四风问题查摆思想汇报
2014/09/13 职场文书
“四风”问题对照检查材料思想汇报
2014/09/16 职场文书
心术观后感
2015/06/11 职场文书
学习十八大的感悟
2015/08/11 职场文书
2016年学校党支部创先争优活动总结
2016/04/05 职场文书
导游词之张家口
2019/12/13 职场文书
Go Gin实现文件上传下载的示例代码
2021/04/02 Golang