python实现mean-shift聚类算法


Posted in Python onJune 10, 2020

本文实例为大家分享了python实现mean-shift聚类算法的具体代码,供大家参考,具体内容如下

1、新建MeanShift.py文件

import numpy as np

# 定义 预先设定 的阈值
STOP_THRESHOLD = 1e-4
CLUSTER_THRESHOLD = 1e-1


# 定义度量函数
def distance(a, b):
 return np.linalg.norm(np.array(a) - np.array(b))


# 定义高斯核函数
def gaussian_kernel(distance, bandwidth):
 return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth)) ** 2)


# mean_shift类
class mean_shift(object):
 def __init__(self, kernel=gaussian_kernel):
  self.kernel = kernel

 def fit(self, points, kernel_bandwidth):

  shift_points = np.array(points)
  shifting = [True] * points.shape[0]

  while True:
   max_dist = 0
   for i in range(0, len(shift_points)):
    if not shifting[i]:
     continue
    p_shift_init = shift_points[i].copy()
    shift_points[i] = self._shift_point(shift_points[i], points, kernel_bandwidth)
    dist = distance(shift_points[i], p_shift_init)
    max_dist = max(max_dist, dist)
    shifting[i] = dist > STOP_THRESHOLD

   if(max_dist < STOP_THRESHOLD):
    break
  cluster_ids = self._cluster_points(shift_points.tolist())
  return shift_points, cluster_ids

 def _shift_point(self, point, points, kernel_bandwidth):
  shift_x = 0.0
  shift_y = 0.0
  scale = 0.0
  for p in points:
   dist = distance(point, p)
   weight = self.kernel(dist, kernel_bandwidth)
   shift_x += p[0] * weight
   shift_y += p[1] * weight
   scale += weight
  shift_x = shift_x / scale
  shift_y = shift_y / scale
  return [shift_x, shift_y]

 def _cluster_points(self, points):
  cluster_ids = []
  cluster_idx = 0
  cluster_centers = []

  for i, point in enumerate(points):
   if(len(cluster_ids) == 0):
    cluster_ids.append(cluster_idx)
    cluster_centers.append(point)
    cluster_idx += 1
   else:
    for center in cluster_centers:
     dist = distance(point, center)
     if(dist < CLUSTER_THRESHOLD):
      cluster_ids.append(cluster_centers.index(center))
    if(len(cluster_ids) < i + 1):
     cluster_ids.append(cluster_idx)
     cluster_centers.append(point)
     cluster_idx += 1
  return cluster_ids

2、调用上述py文件

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 09 11:02:08 2018

@author: muli
"""

from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt 
import random
import numpy as np
import MeanShift


def colors(n):
 ret = []
 for i in range(n):
 ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
 return ret

def main():
 centers = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
 X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4)

 mean_shifter = MeanShift.mean_shift()
 _, mean_shift_result = mean_shifter.fit(X, kernel_bandwidth=0.5)

 np.set_printoptions(precision=3)
 print('input: {}'.format(X))
 print('assined clusters: {}'.format(mean_shift_result))
 color = colors(np.unique(mean_shift_result).size)

 for i in range(len(mean_shift_result)):
  plt.scatter(X[i, 0], X[i, 1], color = color[mean_shift_result[i]])
 plt.show()


if __name__ == '__main__':
 main()

结果如图所示:

python实现mean-shift聚类算法

参考链接

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学Python之入门(二)基本数据类型
May 25 Python
python处理PHP数组文本文件实例
Sep 18 Python
浅析Python中将单词首字母大写的capitalize()方法
May 18 Python
fastcgi文件读取漏洞之python扫描脚本
Apr 23 Python
python距离测量的方法
Mar 06 Python
Django实现全文检索的方法(支持中文)
May 14 Python
简单介绍django提供的加密算法
Dec 18 Python
Python进阶之迭代器与迭代器切片教程
Jan 29 Python
Python无头爬虫下载文件的实现
Apr 02 Python
解决echarts中饼图标签重叠的问题
May 16 Python
python利用opencv实现颜色检测
Feb 23 Python
详解python网络进程
Jun 15 Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
You might like
PHP中显示格式化的用户输入
2006/10/09 PHP
基于mysql的论坛(4)
2006/10/09 PHP
一个取得文件扩展名的函数
2006/10/09 PHP
php程序内部post数据的方法
2015/03/31 PHP
php中file_exists函数使用详解
2015/05/08 PHP
PHP常用设计模式之委托设计模式
2016/02/13 PHP
配置Nginx+PHP的正确思路与过程
2016/05/10 PHP
Avengerls vs Newbee BO3 第三场2.18
2021/03/10 DOTA
JavaScript RegExp方法获取地址栏参数(面向对象)
2009/03/10 Javascript
简单实用的js调试logger组件实现代码
2010/11/20 Javascript
jquery获取div距离窗口和父级dv的距离示例
2013/10/10 Javascript
JavaScript实现简单的时钟实例代码
2013/11/23 Javascript
使用jQuery判断IE浏览器版本的代码
2014/06/14 Javascript
利用Node.JS实现邮件发送功能
2016/10/21 Javascript
JQuery 获取多个select标签option的text内容(实例)
2017/09/07 jQuery
python通过urllib2爬网页上种子下载示例
2014/02/24 Python
linux 下实现python多版本安装实践
2014/11/18 Python
详解Python3的TFTP文件传输
2018/06/26 Python
PyQt5实现五子棋游戏(人机对弈)
2020/03/24 Python
30行Python代码实现高分辨率图像导航的方法
2020/05/22 Python
python中get和post有什么区别
2020/06/19 Python
HTML5 表单验证失败的提示语问题
2017/07/13 HTML / CSS
西铁城美国官方网站:Citizen Watch美国
2019/11/08 全球购物
给水排水工程专业毕业生推荐信
2013/10/28 职场文书
兼职学生的自我评价
2013/11/24 职场文书
受欢迎的大学生自我评价
2013/12/05 职场文书
学生党支部先进事迹
2014/02/04 职场文书
《最佳路径》教学反思
2014/04/13 职场文书
产品包装策划方案
2014/05/18 职场文书
招标承诺书
2014/08/30 职场文书
中学生2014国庆节演讲稿:不屈的民族
2014/09/21 职场文书
自查自纠整改报告
2014/11/06 职场文书
单位租车协议书
2015/01/29 职场文书
大学新生入学感想
2015/08/07 职场文书
2016初一新生军训心得体会
2016/01/11 职场文书
go mod 安装依赖 unkown revision问题的解决方案
2021/05/06 Golang