Python实现的Kmeans++算法实例


Posted in Python onApril 26, 2014

1、从Kmeans说起

Kmeans是一个非常基础的聚类算法,使用了迭代的思想,关于其原理这里不说了。下面说一下如何在matlab中使用kmeans算法。

创建7个二维的数据点:

x=[randn(3,2)*.4;randn(4,2)*.5+ones(4,1)*[4 4]];

使用kmeans函数:
class = kmeans(x, 2);

x是数据点,x的每一行代表一个数据;2指定要有2个中心点,也就是聚类结果要有2个簇。 class将是一个具有70个元素的列向量,这些元素依次对应70个数据点,元素值代表着其对应的数据点所处的分类号。某次运行后,class的值是:
 2 
 2
 2
 1
 1
 1
 1

这说明x的前三个数据点属于簇2,而后四个数据点属于簇1。 kmeans函数也可以像下面这样使用:
>> [class, C, sumd, D] = kmeans(x, 2)
class =
     2
     2
     2
     1
     1
     1
     1
C =
    4.0629    4.0845
   -0.1341    0.1201
sumd =
    1.2017
    0.2939
D =
   34.3727    0.0184
   29.5644    0.1858
   36.3511    0.0898
    0.1247   37.4801
    0.7537   24.0659
    0.1979   36.7666
    0.1256   36.2149

class依旧代表着每个数据点的分类;C包含最终的中心点,一行代表一个中心点;sumd代表着每个中心点与所属簇内各个数据点的距离之和;D的每一行也对应一个数据点,行中的数值依次是该数据点与各个中心点之间的距离,Kmeans默认使用的距离是欧几里得距离(参考资料[3])的平方值。kmeans函数使用的距离,也可以是曼哈顿距离(L1-距离),以及其他类型的距离,可以通过添加参数指定。

kmeans有几个缺点(这在很多资料上都有说明):

1、最终簇的类别数目(即中心点或者说种子点的数目)k并不一定能事先知道,所以如何选一个合适的k的值是一个问题。
2、最开始的种子点的选择的好坏会影响到聚类结果。
3、对噪声和离群点敏感。
4、等等。

2、kmeans++算法的基本思路

kmeans++算法的主要工作体现在种子点的选择上,基本原则是使得各个种子点之间的距离尽可能的大,但是又得排除噪声的影响。 以下为基本思路:

1、从输入的数据点集合(要求有k个聚类)中随机选择一个点作为第一个聚类中心
2、对于数据集中的每一个点x,计算它与最近聚类中心(指已选择的聚类中心)的距离D(x)
3、选择一个新的数据点作为新的聚类中心,选择的原则是:D(x)较大的点,被选取作为聚类中心的概率较大
4、重复2和3直到k个聚类中心被选出来
5、利用这k个初始的聚类中心来运行标准的k-means算法

假定数据点集合X有n个数据点,依次用X(1)、X(2)、……、X(n)表示,那么,在第2步中依次计算每个数据点与最近的种子点(聚类中心)的距离,依次得到D(1)、D(2)、……、D(n)构成的集合D。在D中,为了避免噪声,不能直接选取值最大的元素,应该选择值较大的元素,然后将其对应的数据点作为种子点。

如何选择值较大的元素呢,下面是一种思路(暂未找到最初的来源,在资料[2]等地方均有提及,笔者换了一种让自己更好理解的说法): 把集合D中的每个元素D(x)想象为一根线L(x),线的长度就是元素的值。将这些线依次按照L(1)、L(2)、……、L(n)的顺序连接起来,组成长线L。L(1)、L(2)、……、L(n)称为L的子线。根据概率的相关知识,如果我们在L上随机选择一个点,那么这个点所在的子线很有可能是比较长的子线,而这个子线对应的数据点就可以作为种子点。下文中kmeans++的两种实现均是这个原理。

3、python版本的kmeans++

在http://rosettacode.org/wiki/K-means%2B%2B_clustering 中能找到多种编程语言版本的Kmeans++实现。下面的内容是基于python的实现(中文注释是笔者添加的):

from math import pi, sin, cos
from collections import namedtuple
from random import random, choice
from copy import copy
try:
    import psyco
    psyco.full()
except ImportError:
    pass
FLOAT_MAX = 1e100
class Point:
    __slots__ = ["x", "y", "group"]
    def __init__(self, x=0.0, y=0.0, group=0):
        self.x, self.y, self.group = x, y, group
def generate_points(npoints, radius):
    points = [Point() for _ in xrange(npoints)]
    # note: this is not a uniform 2-d distribution
    for p in points:
        r = random() * radius
        ang = random() * 2 * pi
        p.x = r * cos(ang)
        p.y = r * sin(ang)
    return points
def nearest_cluster_center(point, cluster_centers):
    """Distance and index of the closest cluster center"""
    def sqr_distance_2D(a, b):
        return (a.x - b.x) ** 2  +  (a.y - b.y) ** 2
    min_index = point.group
    min_dist = FLOAT_MAX
    for i, cc in enumerate(cluster_centers):
        d = sqr_distance_2D(cc, point)
        if min_dist > d:
            min_dist = d
            min_index = i
    return (min_index, min_dist)
'''
points是数据点,nclusters是给定的簇类数目
cluster_centers包含初始化的nclusters个中心点,开始都是对象->(0,0,0)
'''
def kpp(points, cluster_centers):
    cluster_centers[0] = copy(choice(points)) #随机选取第一个中心点
    d = [0.0 for _ in xrange(len(points))]  #列表,长度为len(points),保存每个点离最近的中心点的距离
    for i in xrange(1, len(cluster_centers)):  # i=1...len(c_c)-1
        sum = 0
        for j, p in enumerate(points):
            d[j] = nearest_cluster_center(p, cluster_centers[:i])[1] #第j个数据点p与各个中心点距离的最小值
            sum += d[j]
        sum *= random()
        for j, di in enumerate(d):
            sum -= di
            if sum > 0:
                continue
            cluster_centers[i] = copy(points[j])
            break
    for p in points:
        p.group = nearest_cluster_center(p, cluster_centers)[0]
'''
points是数据点,nclusters是给定的簇类数目
'''
def lloyd(points, nclusters):
    cluster_centers = [Point() for _ in xrange(nclusters)]  #根据指定的中心点个数,初始化中心点,均为(0,0,0)
    # call k++ init
    kpp(points, cluster_centers)   #选择初始种子点
    # 下面是kmeans
    lenpts10 = len(points) >> 10
    changed = 0
    while True:
        # group element for centroids are used as counters
        for cc in cluster_centers:
            cc.x = 0
            cc.y = 0
            cc.group = 0
        for p in points:
            cluster_centers[p.group].group += 1  #与该种子点在同一簇的数据点的个数
            cluster_centers[p.group].x += p.x
            cluster_centers[p.group].y += p.y
        for cc in cluster_centers:    #生成新的中心点
            cc.x /= cc.group
            cc.y /= cc.group
        # find closest centroid of each PointPtr
        changed = 0  #记录所属簇发生变化的数据点的个数
        for p in points:
            min_i = nearest_cluster_center(p, cluster_centers)[0]
            if min_i != p.group:
                changed += 1
                p.group = min_i
        # stop when 99.9% of points are good
        if changed <= lenpts10:
            break
    for i, cc in enumerate(cluster_centers):
        cc.group = i
    return cluster_centers
def print_eps(points, cluster_centers, W=400, H=400):
    Color = namedtuple("Color", "r g b");
    colors = []
    for i in xrange(len(cluster_centers)):
        colors.append(Color((3 * (i + 1) % 11) / 11.0,
                            (7 * i % 11) / 11.0,
                            (9 * i % 11) / 11.0))
    max_x = max_y = -FLOAT_MAX
    min_x = min_y = FLOAT_MAX
    for p in points:
        if max_x < p.x: max_x = p.x
        if min_x > p.x: min_x = p.x
        if max_y < p.y: max_y = p.y
        if min_y > p.y: min_y = p.y
    scale = min(W / (max_x - min_x),
                H / (max_y - min_y))
    cx = (max_x + min_x) / 2
    cy = (max_y + min_y) / 2
    print "%%!PS-Adobe-3.0\n%%%%BoundingBox: -5 -5 %d %d" % (W + 10, H + 10)
    print ("/l {rlineto} def /m {rmoveto} def\n" +
           "/c { .25 sub exch .25 sub exch .5 0 360 arc fill } def\n" +
           "/s { moveto -2 0 m 2 2 l 2 -2 l -2 -2 l closepath " +
           "   gsave 1 setgray fill grestore gsave 3 setlinewidth" +
           " 1 setgray stroke grestore 0 setgray stroke }def")
    for i, cc in enumerate(cluster_centers):
        print ("%g %g %g setrgbcolor" %
               (colors[i].r, colors[i].g, colors[i].b))
        for p in points:
            if p.group != i:
                continue
            print ("%.3f %.3f c" % ((p.x - cx) * scale + W / 2,
                                    (p.y - cy) * scale + H / 2))
        print ("\n0 setgray %g %g s" % ((cc.x - cx) * scale + W / 2,
                                        (cc.y - cy) * scale + H / 2))
    print "\n%%%%EOF"
def main():
    npoints = 30000
    k = 7 # # clusters
    points = generate_points(npoints, 10)
    cluster_centers = lloyd(points, k)
    print_eps(points, cluster_centers)
main()

上述代码实现的算法是针对二维数据的,所以Point对象有三个属性,分别是在x轴上的值、在y轴上的值、以及所属的簇的标识。函数lloyd是kmeans++算法的整体实现,其先是通过kpp函数选取合适的种子点,然后对数据集实行kmeans算法进行聚类。kpp函数的实现完全符合上述kmeans++的基本思路的2、3、4步。

4、matlab版本的kmeans++

function [L,C] = kmeanspp(X,k)
%KMEANS Cluster multivariate data using the k-means++ algorithm.
%   [L,C] = kmeans_pp(X,k) produces a 1-by-size(X,2) vector L with one class
%   label per column in X and a size(X,1)-by-k matrix C containing the
%   centers corresponding to each class.
%   Version: 2013-02-08
%   Authors: Laurent Sorber (Laurent.Sorber@cs.kuleuven.be)
L = [];
L1 = 0;
while length(unique(L)) ~= k
    % The k-means++ initialization.
    C = X(:,1+round(rand*(size(X,2)-1))); %size(X,2)是数据集合X的数据点的数目,C是中心点的集合
    L = ones(1,size(X,2));
    for i = 2:k
        D = X-C(:,L); %-1 
        D = cumsum(sqrt(dot(D,D,1))); %将每个数据点与中心点的距离,依次累加
        if D(end) == 0, C(:,i:k) = X(:,ones(1,k-i+1)); return; end
        C(:,i) = X(:,find(rand < D/D(end),1)); %find的第二个参数表示返回的索引的数目
        [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).')); %碉堡了,这句,将每个数据点进行分类。
    end
    % The k-means algorithm.
    while any(L ~= L1)
        L1 = L;
        for i = 1:k, l = L==i; C(:,i) = sum(X(:,l),2)/sum(l); end
        [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'),[],1);
    end
end

这个函数的实现有些特殊,参数X是数据集,但是是将每一列看做一个数据点,参数k是指定的聚类数。返回值L标记了每个数据点的所属分类,返回值C保存了最终形成的中心点(一列代表一个中心点)。测试一下:

>> x=[randn(3,2)*.4;randn(4,2)*.5+ones(4,1)*[4 4]]
x =
   -0.0497    0.5669
    0.5959    0.2686
    0.5636   -0.4830
    4.3586    4.3634
    4.8151    3.8483
    4.2444    4.1469
    4.5173    3.6064
>> [L, C] = kmeanspp(x',2)
L =
     2     2     2     1     1     1     1
C =
    4.4839    0.3699
    3.9913    0.1175

好了,现在开始一点点理解这个实现,顺便巩固一下matlab知识。

unique函数用来获取一个矩阵中的不同的值,示例:

>> unique([1 3 3 4 4 5])
ans =
     1     3     4     5
>> unique([1 3 3 ; 4 4 5])
ans =
     1
     3
     4
     5

所以循环 while length(unique(L)) ~= k 以得到了k个聚类为结束条件,不过一般情况下,这个循环一次就结束了,因为重点在这个循环中。

rand是返回在(0,1)这个区间的一个随机数。在注释%-1所在行,C被扩充了,被扩充的方法类似于下面:

>> C =[];
>> C(1,1) = 1
C =
     1
>> C(2,1) = 2
C =
     1
     2
>> C(:,[1 1 1 1])
ans =
     1     1     1     1
     2     2     2     2
>> C(:,[1 1 1 1 2])
Index exceeds matrix dimensions.

C中第二个参数的元素1,其实是代表C的第一列数据,之所以在值2时候出现Index exceeds matrix dimensions.的错误,是因为C本身没有第二列。如果C有第二列了:

>> C(2,2) = 3;
>> C(2,2) = 4;
>> C(:,[1 1 1 1 2])
ans =
     1     1     1     1     3
     2     2     2     2     4

dot函数是将两个矩阵点乘,然后把结果在某一维度相加:
>> TT = [1 2 3 ; 4 5 6];
>> dot(TT,TT)
ans =
    17    29    45
>> dot(TT,TT,1 )
ans =
    17    29    45

<code>cumsum</code>是累加函数:
>> cumsum([1 2 3])
ans =
     1     3     6
>> cumsum([1 2 3; 4 5 6])
ans =
     1     2     3
     5     7     9

max函数可以返回两个值,第二个代表的是max数的索引位置:
>> [~, L] = max([1 2 3])
L =
     3
>> [~,L] = max([1 2 3;2 3 4])
L =
     2     2     2

其中~是占位符。

关于bsxfun函数,官方文档指出:

C = bsxfun(fun,A,B) applies the element-by-element binary operation specified by the function handle fun to arrays A and B, with singleton expansion enabled

其中参数fun是函数句柄,关于函数句柄见资料[9]。下面是bsxfun的一个示例:
>> A= [1 2 3;2 3 4]
A =
     1     2     3
     2     3     4
>> B=[6;7]
B =
     6
     7
>> bsxfun(@minus,A,B)
ans =
    -5    -4    -3
    -5    -4    -3

对于:
[~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'));

max的参数是这样一个矩阵,矩阵有n列,n也是数据点的个数,每一列代表着对应的数据点与各个中心点之间的距离的相反数。不过这个距离有些与众不同,算是欧几里得距离的变形。

假定数据点是2维的,某个数据点为(x1,y1),某个中心点为(c1,d1),那么通过bsxfun(@minus,2real(C'X),dot(C,C,1).')的计算,数据点与中心点的距离为2c1x1 + 2d1y1 -c1.^2 - c2.^2,可以变换为x1.^2 + y1.^2 - (c1-x1).^2 - (d1-y1).^2。对于每一列而言,由于是数据点与各个中心点之间的计算,所以可以忽略x1.^2 + y1.^2,最终计算结果是欧几里得距离的平方的相反数。这也说明了使用max的合理性,因为一个数据点的所属簇取决于与其距离最近的中心点,若将距离取相反数,则应该是值最大的那个点。

Python 相关文章推荐
python使用Tkinter显示网络图片的方法
Apr 24 Python
Python编写生成验证码的脚本的教程
May 04 Python
Python计算三角函数之asin()方法的使用
May 15 Python
Python解析树及树的遍历
Feb 03 Python
Python实现OpenCV的安装与使用示例
Mar 30 Python
Python实现的简单读写csv文件操作示例
Jul 12 Python
Python多线程原理与用法详解
Aug 20 Python
Python面向对象之类的定义与继承用法示例
Jan 14 Python
Python检查ping终端的方法
Jan 26 Python
Python实现微信消息防撤回功能的实例代码
Apr 29 Python
python爬虫项目设置一个中断重连的程序的实现
Jul 26 Python
python 制作一个gui界面的翻译工具
May 14 Python
爬山算法简介和Python实现实例
Apr 26 #Python
Python操作sqlite3快速、安全插入数据(防注入)的实例
Apr 26 #Python
python实现的二叉树算法和kmp算法实例
Apr 25 #Python
python中的__init__ 、__new__、__call__小结
Apr 25 #Python
Python yield 小结和实例
Apr 25 #Python
python计数排序和基数排序算法实例
Apr 25 #Python
python处理圆角图片、圆形图片的例子
Apr 25 #Python
You might like
WIN98下Apache1.3.14+PHP4.0.4的安装
2006/10/09 PHP
PHP在线生成二维码(google api)的实现代码详解
2013/06/04 PHP
ThinkPHP3.1基础知识快速入门
2014/06/19 PHP
Laravel中扩展Memcached缓存驱动实现使用阿里云OCS缓存
2015/02/10 PHP
php中get_cfg_var()和ini_get()的用法及区别
2015/03/04 PHP
php解决安全问题的方法实例
2019/09/19 PHP
Asp.net下利用Jquery Ajax实现用户注册检测(验证用户名是否存)
2010/09/12 Javascript
JQuery的Ajax请求实现局部刷新的简单实例
2014/02/11 Javascript
JS兼容浏览器的导出Excel(CSV)文件的方法
2014/05/03 Javascript
js实现的彩色方块飞舞奇幻效果
2016/01/27 Javascript
全面解析Bootstrap中form、navbar的使用方法
2016/05/30 Javascript
AngularJs Understanding the Controller Component
2016/09/02 Javascript
JS判断是否手机或pad访问实现方法
2016/12/09 Javascript
jquery插件ContextMenu设置右键菜单
2017/03/13 Javascript
json数据传到前台并解析展示成列表的方法
2018/08/06 Javascript
使用JS实现导航切换时高亮显示的示例讲解
2018/08/22 Javascript
vue-router传参用法详解
2019/01/19 Javascript
angular6根据environments配置文件更改开发所需要的环境的方法
2019/03/06 Javascript
Vue 实现复制功能,不需要任何结构内容直接复制方式
2019/11/09 Javascript
python使用rabbitmq实现网络爬虫示例
2014/02/20 Python
python添加模块搜索路径方法
2017/09/11 Python
python快速建立超简单的web服务器的实现方法
2018/02/17 Python
Python实现的生产者、消费者问题完整实例
2018/05/30 Python
Python 旋转打印各种矩形的方法
2019/07/09 Python
keras中的loss、optimizer、metrics用法
2020/06/15 Python
用python实现一个简单的验证码
2020/12/09 Python
璀璨的珍珠、密钉和个性化珠宝:Lily & Roo
2021/01/21 全球购物
法律工作求职自荐信
2013/10/31 职场文书
优秀家长自荐材料
2014/08/26 职场文书
田径运动会通讯稿
2014/09/13 职场文书
房屋租赁合同协议书范本
2014/10/19 职场文书
2015年试用期工作总结范文
2015/05/28 职场文书
如何让2019年上半年的工作总结更出色!
2019/07/01 职场文书
《卧薪尝胆》读后感3篇
2019/12/26 职场文书
制作能在nginx和IIS中使用的ssl证书
2021/06/21 Servers
与Windows10相比Windows11有哪些改进?值不值得升级?
2021/11/21 数码科技