Python实现Mean Shift算法

2023-11-12

       声明:代码的运行环境为Python3。Python3与Python2在一些细节上会有所不同,希望广大读者注意。本博客以代码为主,代码中会有详细的注释。相关文章将会发布在我的个人博客专栏《Python从入门到深度学习》,欢迎大家关注~


       在K-Means算法中,聚类的类别个数需要提前指定,对于类别个数未知的数据集,K-Means算法和K-Means++算法将很难对其进行求解,所以需要一些能够处理未知类别个数的算法来处理此类问题。Mean Shift算法,又称作均值漂移算法,它跟K-Means算法一样,都是基于聚类中心的聚类算法,不同的是,它不需要提前指定聚类中心的个数,聚类中心是通过在给定区域中样本的均值来确定的,通过不断更新聚类中心,直至聚类中心不再改变为止。

一、Mean Shift向量与核函数

1、Mean Shift向量

       对于给定的n维空间中的m个样本点,对于其中的一个样本X,其Mean Shift的向量为:

       其中,指的是一个半径为h的高维球区域,定义为:

2、核函数

       通过上述方式求出的Mean Shift向量时存在问题的,即在区域内每一个对样本X的贡献是一样的,然而实际上,每一个样本对样本X的贡献是不一样的,我们可以通过核函数对每一个样本的贡献进行度量。

       核函数的定义如下:

       设Z是输入空间,H是特征空间,如果存在一个Z到H的映射:使得所有,函数满足条件:,则称为核函数,为映射函数。

       我们在Mean Shift算法中使用的是高斯核函数,这也是最常用的核函数之一,高斯核函数的表达式为:

       其中,h为带宽,当带宽一定时,样本点之间的距离越近,核函数的值越大;当样本点距离一定时,带宽越大,核函数的值越小。

       下面我们使用Python代码实现高斯核函数:

import numpy as np
import math

def gs_kernel(dist, h):
    '''
    高斯核函数
    :param dist: 欧氏距离
    :param h: 带宽
    :return: 返回高斯核函数的值
    '''
    m = np.shape(dist)[0]  # 样本个数
    one = 1 / (h * math.sqrt(2 * math.pi))
    two = np.mat(np.zeros((m, 1)))
    for i in range(m):
        two[i, 0] = (-0.5 * dist[i] * dist[i].T) / (h * h)
        two[i, 0] = np.exp(two[i, 0])
    
    gs_val = one * two
    return gs_val

二、Mean Shift原理

       在Mean Shift中通过迭代的方式找到最终的聚类中心,即对每一个样本点计算其漂移均值,以计算出来的漂移均值点作为新的起始点重复上述步骤,直到满足终止条件,得到的最终的漂移均值点即为最终的聚类中心。

       Mean Shift算法实现过程如下:

def mean_shift(points, h=2, MIN_DISTANCE=0.000001):
    '''
    训练Mean Shift模型
    :param points: 特征点
    :param h: 带宽
    :param MIN_DISTANCE: 最小误差
    :return: 返回特征点、均值漂移点、类别
    '''
    mean_shift_points = np.mat(points)
    max_min_dist = 1
    iteration = 0  # 迭代的次数
    m = np.shape(mean_shift_points)[0]  # 样本的个数
    need_shift = [True] * m  # 标记是否需要漂移

    # 计算均值漂移向量
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        print("iteration : " + str(iteration))
        for i in range(0, m):
            if not need_shift[i]:  # 判断每一个样本点是否需要计算偏移均值
                continue
            point_new = mean_shift_points[i]
            point_new_start = point_new
            point_new  = shift_point(point_new, points, h)  # 对样本点进行漂移计算
            dist = distince(point_new, point_new_start)  # 计算该点与漂移后的点之间的距离
            
            if dist > max_min_dist:
                max_min_dist = dist
            if dist < MIN_DISTANCE:
                need_shift[i] = False
            
            mean_shift_points[i] = point_new
    # 计算最终的类别
    lb = lb_points(mean_shift_points)  # 计算所属的类别
    return np.mat(points), mean_shift_points, lb

       其中,shift_point()方法目的在于计算漂移量,lb_points()方法的目的在于计算最终所属分类,distance()方法用于计算欧氏距离,三个方法的实现过程分别如下:

(1)shift_point()方法

def shift_point(point, points, h):
    '''
    计算漂移向量
    :param point: 需要计算的点
    :param points: 所有的样本点
    :param h: 带宽
    :return: 返回漂移后的点
    '''
    points = np.mat(points)
    m = np.shape(points)[0]  # 样本的个数
    # 计算距离
    point_dist = np.mat(np.zeros((m, 1)))
    for i in range(m):
        point_dist[i, 0] = distince(point, points[i])

    # 计算高斯核函数
    point_weights = gs_kernel(point_dist, h)

    # 计算分母
    all_sum = 0.0
    for i in range(m):
        all_sum += point_weights[i, 0]

    # 计算均值偏移
    point_shifted = point_weights.T * points / all_sum
    return point_shifted

(2)lb_points()

def lb_points(mean_shift_points):
    '''
    计算所属类别
    :param mean_shift_points: 漂移向量
    :return: 返回所属的类别
    '''
    lb_list = []
    m, n = np.shape(mean_shift_points)
    index = 0
    index_dict = {}
    for i in range(m):
        item = []
        for j in range(n):
            item.append(str(("%5.2f" % mean_shift_points[i, j])))

        item_1 = "_".join(item)
        if item_1 not in index_dict:
            index_dict[item_1] = index
            index += 1

    for i in range(m):
        item = []
        for j in range(n):
            item.append(str(("%5.2f" % mean_shift_points[i, j])))

        item_1 = "_".join(item)
        lb_list.append(index_dict[item_1])
    return lb_list

(3)distince()方法

def distince(pointA, pointB):
    '''
    计算欧氏距离
    :param pointA: A点坐标
    :param pointB: B点坐标
    :return: 返回得到的欧氏距离
    '''
    return math.sqrt((pointA - pointB) * (pointA - pointB).T)

三、Mean Shift算法举例

1、数据集:数据集含有两个特征,如下图所示:

2、加载数据集

       我们此处使用如下方法加载数据集,也可使用其他的方式进行加载,此处可以参考我的另外一篇文章《Python两种方式加载文件内容》。加载文件内容代码如下:

def load_data(path, feature_num=2):
    '''
    导入数据
    :param path: 路径
    :param feature_num: 特行总数
    :return: 返回数据特征
    '''
    f = open(path)
    data = []
    for line in f.readlines():
        lines = line.strip().split("\t")
        data_tmp = []
        if len(lines) != feature_num:  # 判断特征的个数是否正确,把不符合特征数的数据去除
            continue
        for i in range(feature_num):
            data_tmp.append(float(lines[i]))
        data.append(data_tmp)
    f.close()
    return data

3、保存聚类结果

       通过Mean Shift聚类后,我们使用一下方法进行聚类结果的保存

def save_result(file_name, data):
    '''
    保存聚类结果
    :param file_name: 保存的文件名
    :param data: 需要保存的文件
    :return:
    '''
    f = open(file_name, "w")
    m, n = np.shape(data)
    for i in range(m):
        tmp = []
        for j in range(n):
            tmp.append(str(data[i, j]))
        f.write("\t".join(tmp) + "\n")
    f.close()

4、调用Mean Shift算法

if __name__ == "__main__":
    data = load_data("F://data", 2)
    points, shift_points, cluster = mean_shift(data, 2)
    save_result("sub", np.mat(cluster))
    save_result("center", shift_points)

5、结果展示

       得到的聚类结果如下所示:

        你们在此过程中遇到了什么问题,欢迎留言,让我看看你们都遇到了哪些问题。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python实现Mean Shift算法 的相关文章

  • 使用 pythonbrew 编译 Python 3.2 和 2.7 时出现问题

    我正在尝试使用构建多个版本的 python蟒蛇酿造 http pypi python org pypi pythonbrew 0 7 3 但我遇到了一些测试失败 这是在运行的虚拟机上 Ubuntu 8 04 32 位 当我使用时会发生这种情
  • 将数据从 python pandas 数据框导出或写入 MS Access 表

    我正在尝试将数据从 python pandas 数据框导出到现有的 MS Access 表 我想用已更新的数据替换 MS Access 表 在 python 中 我尝试使用 pandas to sql 但收到错误消息 我觉得很奇怪 使用 p
  • 将 Matplotlib 误差线放置在不位于条形中心的位置

    我正在 Matplotlib 中生成带有错误栏的堆积条形图 不幸的是 某些层相对较小且数据多样 因此多个层的错误条可能重叠 从而使它们难以或无法读取 Example 有没有办法设置每个误差条的位置 即沿 x 轴移动它 以便重叠的线显示在彼此
  • Django:按钮链接

    我是一名 Django 新手用户 尝试创建一个按钮 单击该按钮会链接到我网站中的另一个页面 我尝试了一些不同的例子 但似乎没有一个对我有用 举个例子 为什么这不起作用
  • Flask 会话变量

    我正在用 Flask 编写一个小型网络应用程序 当两个用户 在同一网络下 尝试使用应用程序时 我遇到会话变量问题 这是代码 import os from flask import Flask request render template
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • Python pickle:腌制对象不等于源对象

    我认为这是预期的行为 但想检查一下 也许找出原因 因为我所做的研究结果是空白 我有一个函数可以提取数据 创建自定义类的新实例 然后将其附加到列表中 该类仅包含变量 然后 我使用协议 2 作为二进制文件将该列表腌制到文件中 稍后我重新运行脚本
  • Python 函数可以从作用域之外赋予新属性吗?

    我不知道你可以这样做 def tom print tom s locals locals def dick z print z name z name z guest Harry print z guest z guest print di
  • AWS EMR Spark Python 日志记录

    我正在 AWS EMR 上运行一个非常简单的 Spark 作业 但似乎无法从我的脚本中获取任何日志输出 我尝试过打印到 stderr from pyspark import SparkContext import sys if name m
  • 在f字符串中转义字符[重复]

    这个问题在这里已经有答案了 我遇到了以下问题f string gt gt gt a hello how to print hello gt gt gt f a a gt gt gt f a File
  • Pandas:merge_asof() 对多行求和/不重复

    我正在处理两个数据集 每个数据集具有不同的关联日期 我想合并它们 但因为日期不完全匹配 我相信merge asof 是最好的方法 然而 有两件事发生merge asof 不理想的 数字重复 数字丢失 以下代码是一个示例 df a pd Da
  • 向 Altair 图表添加背景实心填充

    I like Altair a lot for making graphs in Python As a tribute I wanted to regenerate the Economist graph s in Mistakes we
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 每个 X 具有多个 Y 值的 Python 散点图

    我正在尝试使用 Python 创建一个散点图 其中包含两个 X 类别 cat1 cat2 每个类别都有多个 Y 值 如果每个 X 值的 Y 值的数量相同 我可以使用以下代码使其工作 import numpy as np import mat
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • 在 Qt 中自动调整标签文本大小 - 奇怪的行为

    在 Qt 中 我有一个复合小部件 它由排列在 QBoxLayouts 内的多个 QLabels 组成 当小部件调整大小时 我希望标签文本缩放以填充标签区域 并且我已经在 resizeEvent 中实现了文本大小的调整 这可行 但似乎发生了某
  • Python 类继承 - 诡异的动作

    我观察到类继承有一个奇怪的效果 对于我正在处理的项目 我正在创建一个类来充当另一个模块的类的包装器 我正在使用第 3 方 aeidon 模块 用于操作字幕文件 但问题可能不太具体 以下是您通常如何使用该模块 project aeidon P
  • 导入错误:没有名为 site 的模块 - mac

    我已经有这个问题几个月了 每次我想获取一个新的 python 包并使用它时 我都会在终端中收到此错误 ImportError No module named site 我不知道为什么会出现这个错误 实际上 我无法使用任何新软件包 因为每次我
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2
  • NotImplementedError:无法将符号张量 (lstm_2/strided_slice:0) 转换为 numpy 数组。时间

    张量流版本 2 3 1 numpy 版本 1 20 在代码下面 define model model Sequential model add LSTM 50 activation relu input shape n steps n fe

随机推荐

  • LeetCode 20. 有效的括号

    题目链接 https leetcode cn problems valid parentheses C 代码如下 class Solution public bool isValid string s stack
  • The Standard C Library

    C的标志库函数是学习和使用C语言的基础 是编写经典C程序的基础 是学习其他计算机知识的基础 C标志库中一共包含了15个头文件
  • matlab求lypunov,【原创】Lyapunov、Sylvester和Riccati方程的Matlab求解

    Lyapunov Sylvester和Riccati方程是控系统常用到的几个方程 应用和计算比较广泛 在这里我们只要讨论下Lypunov方程的连续方程 离散方程的数值和解析解法 其中数值解法MATLAB提供的直接的lyap 和dlyap 函
  • ClassNotFoundException:NullPointerException:ArrayIndexOutOfBounException:FileNotFoundException:等异常

    目录 1 ClassNotFoundException 解决方法 2 NullPointerException 解决方法 3 ArrayIndexOutOfBoundsException 解决方法 4 FileNotFoundExcepti
  • EasySwoole ElasticSearch打造高性能小视频服务系统

    好久没有更新教程 现在更新一套缓存视频给大家 Elasticsearch的索引思路 将磁盘里的东西尽量搬进内存 减少磁盘随机读取次数 同时也利用磁盘顺序读特性 结合各种奇技淫巧的压缩算法 用及其苛刻的态度使用内存 所以 对于使用Elasti
  • 操作系统-管道通信

    编写程序 演示多进程并发执行和进程软中断 管道通信 父进程使用系统调用pipe 建立一个管道 然后使用系统调用fork 创建两个子进程 子进程1和子进程2 子进程1每隔1秒通过管道向子进程2发送数据 I send you x times x
  • C++ STL 集合set

    本文主要简述集合的原理和用法 便于快速学习和查阅 集合的原理 set是一个内部自动有序且不含重复元素的容器 set集合容器实现了红黑树 Red Black Tree 的平衡二叉检索树的数据结构 在插入元素时 它会自动调整二叉树的排列 把该元
  • Spring AOP:面向切面编程的简介和实践

    目录 一 什么是AOP 二 AOP的核心概念 三 Spring AOP的实现方式 第一种 注解配置AOP 第二种 xml配置AOP 一 什么是AOP AOP Aspect Oriented Programming 即面向切面编程 是一种编程
  • 在.NET中使用正则表达式(入门篇)

    转载请注明 敏捷学院 技术资源库 原文链接 http dev mjxy cn a In NET using regular expressions aspx 代码下载 RegexExample zip介绍正则表达式提供了功能强大 灵活而又高
  • ITest:京东数科接口自动化测试实践

    导读 你是否为每天 点点点 的工作而感到索然无味 你是否苦于没有合适的工具而对复杂的测试任务望而却步 频繁变动的接口 重复的功能测试 你 疲惫么 京东数科平台开发团队基于日常接口测试经验 开发了接口测试平台 ITest 通过此平台让研发流程
  • [OpenAirInterface实战-20] :OAI 软件无线电USRP E310硬件详解

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 121094384 第1章 概述 USR
  • ChatGPT 或其它 AI,能用在文书创作上吗?

    新的申请季已经正式开始 一些热门项目的ED截止日期也不再遥远 因此很多准留学生们都已经开始了关于文书的创作 而随着科技的不断发展 以ChatGPT为首的一众AI工具也作为一种辅助手段愈发融入了我们的生活 那么不免就会有一些同学在准备申请时
  • excel基本操作1

    excel隔行设置样式 条件格式 gt 条件规则 gt 输入公式 参考https jingyan baidu com article 36d6ed1f2379c35acf4883e0 html excel隔列取值 使用Index结合row和
  • 无线专题 osi模型、TCP/IP五层模型、网络编程(一)

    一 OSI介绍 1 OSI的来源 OSI Open System Interconnect 即开放式系统互联 一般都叫OSI参考模型 是ISO 国际标准化组织 组织在1985年研究的网络互连模型 ISO为了更好的使网络应用更为普及 推出了O
  • Kafka消费者详解

    一 Kafka消费者的消费模式 当生产者将消息发送到Kafka集群后 会转发给消费者进行消费 消息的消费模型有两种 推送模式 push 和拉取模式 pull 1 消息的推送模式 消息的推送模式需要记录消费者的消费状态 当把一条消息推送给消费
  • u盘刷新系统

    1 百度u盘制作将u盘进行刷成系统盘 点击添加系统 确认 关掉即可 到这里就制作完成了 u盘里也有系统了 下一步就是进入电脑的 bios 一般是f8 或者f2 或者esc 看你是什么电脑自己手机百度一下 当进入u盘系统时候会发现一键刷机工具
  • 【计算机网络】HTTP协议详解(八):HTTP网关

    HTTP网关 文章目录 HTTP网关 一 网关 Gateway 二 网关的分类 1 HTTP 服务器端网关 2 HTTP 客户端网关 3 HTTP HTTPS 服务器端安全网关 4 HTTPS HTTP 客户端安全加速器网关 5 资源网关
  • sshpass

    sshpass 安装 sshpass Linux 软件工具安装 源码安装 测试 sshpass 在使用 ssh scp 等命令进行远程操作的时候 必须手动输入密码 这就为自动化的执行造成困扰 sshpass 可以在命令行直接使用密码来进行远
  • Android Studio 常用快捷按键

    大小写转换 Cmd Shift U Ctrl Shift U 注释代码 Cmd Ctrl 注释代码 Cmd Option Ctrl Alt 格式化代码 Cmd Option L Ctrl Alt L 清除无效包引用 Option Contr
  • Python实现Mean Shift算法

    声明 代码的运行环境为Python3 Python3与Python2在一些细节上会有所不同 希望广大读者注意 本博客以代码为主 代码中会有详细的注释 相关文章将会发布在我的个人博客专栏 Python从入门到深度学习 欢迎大家关注 在K Me