Machine Learning, Part 2: k-means with a kd-tree

2023-11-02

Pure-Java machine learning code: click to open the link

Introduction to kd-trees: http://www.pelleg.org/shared/hp/kmeans.html

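Our dataset is 8000 points randomly generated from 5 Gaussian distributions. You should be able to make out the underlying Gaussians. The blue boundary marks the "root" kd-node, which covers all the points.

Now look at the children of the root node. Each is a rectangle, and the splitting line is parallel to the Y axis, roughly halfway across.

Now you see the grandchildren of the root. Each is a split of its parent, this time along the X axis.

And so on, splitting on alternating dimensions...

Here are the first seven levels of the kd-tree, all in one picture.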
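The alternating splits described above can be sketched in a few lines of Java. This is a minimal illustration, not the code from the linked page: the class `KdSketch`, its `Node` type, and the choice of a median split are all assumptions made for the example. The split dimension is `depth % d`, so the axes alternate from one level to the next.

```java
import java.util.Arrays;
import java.util.Comparator;

class KdSketch {
    /** A node covering pts[begin, end) together with its bounding box. */
    static final class Node {
        double[] lo, hi;   // opposite corners of the bounding box
        int count;         // number of points under this node
        Node lower, upper; // children; both null for a leaf
    }

    /** Recursively builds the tree. Note: reorders pts in place. */
    static Node build(double[][] pts, int begin, int end, int depth) {
        int d = pts[0].length;
        Node node = new Node();
        node.count = end - begin;
        node.lo = pts[begin].clone();
        node.hi = pts[begin].clone();
        for (int i = begin + 1; i < end; i++) {
            for (int j = 0; j < d; j++) {
                node.lo[j] = Math.min(node.lo[j], pts[i][j]);
                node.hi[j] = Math.max(node.hi[j], pts[i][j]);
            }
        }
        if (node.count <= 1) {
            return node; // a single point is a leaf
        }
        // Alternate the split dimension by depth and split at the median.
        int dim = depth % d;
        Arrays.sort(pts, begin, end, Comparator.comparingDouble(p -> p[dim]));
        int mid = begin + node.count / 2;
        node.lower = build(pts, begin, mid, depth + 1);
        node.upper = build(pts, mid, end, depth + 1);
        return node;
    }

    /** Number of levels in the tree. */
    static int height(Node n) {
        return n == null ? 0 : 1 + Math.max(height(n.lower), height(n.upper));
    }
}
```

Calling `KdSketch.build(points, 0, points.length, 0)` returns the root, whose bounding box covers every point, just like the blue boundary described above.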

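So how does the kd-tree speed up k-means?

Step 1: Build the kd-tree.

Step 2: Run k-means++ initialization once to choose k centroids as the starting points.

Step 3: Starting from the root of the kd-tree, pick the centroid closest to the node's center (node.center) as the best centroid c_0; every other candidate centroid is called c.

Step 4: Descend through the kd-tree node by node (lower, upper) until no point of the node's bounding box, not even its most extreme corner, can belong to another centroid. A point x is closer to c than to c_0 when

(x-c).(x-c) < (x-c_0).(x-c_0), which is equivalent to (c-c_0).(c-c_0) < 2(x-c_0).(c-c_0)

If no point in the box satisfies this for any other candidate c, all the points under the node are assigned to c_0 at once and used to recompute that centroid.

Step 5: Repeat until the distortion (loss) stops decreasing: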

if (distortion <= dist) {
    break;
} else {
    distortion = dist;
}
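To make the box test concrete, here is a minimal sketch of it in isolation. The class name `PruneSketch`, the method `canPrune`, and the sample box and centroids are made up for illustration; `best` stands for the centroid closest to the box center and `test` for another candidate. Only the most favorable corner of the box needs to be checked, because the right-hand side of the inequality is linear in x.

```java
class PruneSketch {
    /**
     * Returns true when no point of the axis-aligned box (center +/- radius)
     * is strictly closer to test than to best, so test can be discarded.
     */
    static boolean canPrune(double[] center, double[] radius, double[] best, double[] test) {
        double lhs = 0.0, rhs = 0.0;
        for (int i = 0; i < center.length; i++) {
            double diff = test[i] - best[i];
            lhs += diff * diff;
            // Pick the box corner that lies furthest in the direction of test.
            double corner = diff > 0 ? center[i] + radius[i] : center[i] - radius[i];
            rhs += (corner - best[i]) * diff;
        }
        return lhs >= 2 * rhs;
    }

    public static void main(String[] args) {
        double[] center = {1.0, 1.0};
        double[] radius = {1.0, 1.0}; // the box [0,2] x [0,2]
        double[] best = {1.0, 1.0};
        // A far-away candidate cannot win any point of the box: prints true.
        System.out.println(canPrune(center, radius, best, new double[] {10.0, 1.0}));
        // A candidate just outside the box wins the point (2,1): prints false.
        System.out.println(canPrune(center, radius, best, new double[] {2.5, 1.0}));
    }
}
```

In the second call the box point (2, 1) is at squared distance 1 from `best` but only 0.25 from `test`, so the candidate has to be kept and the node's children examined.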


The code is as follows.

Constructing KMeans

public KMeans(double[][] data, int k, int maxIter, int runs) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }

    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }

    if (runs <= 0) {
        throw new IllegalArgumentException("Invalid number of runs: " + runs);
    }

    BBDTree bbd = new BBDTree(data);

    List<KMeansThread> tasks = new ArrayList<>();
    for (int i = 0; i < runs; i++) {
        tasks.add(new KMeansThread(bbd, data, k, maxIter));
    }

    KMeans best = new KMeans();
    best.distortion = Double.MAX_VALUE;

    try {
        List<KMeans> clusters = MulticoreExecutor.run(tasks);
        for (KMeans kmeans : clusters) {
            if (kmeans.distortion < best.distortion) {
                best = kmeans;
            }
        }
    } catch (Exception ex) {
        logger.error("Failed to run K-Means on multi-core", ex);

        for (int i = 0; i < runs; i++) {
            KMeans kmeans = lloyd(data, k, maxIter);
            if (kmeans.distortion < best.distortion) {
                best = kmeans;
            }
        }
    }

    this.k = best.k;
    this.distortion = best.distortion;
    this.centroids = best.centroids;
    this.y = best.y;
    this.size = best.size;
}
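The constructor above runs k-means `runs` times and keeps the result with the lowest distortion. `MulticoreExecutor` belongs to Smile; the sketch below shows the same pattern using only `java.util.concurrent`. The `Result` class and the `bestOf` method are hypothetical names, and each task is merely a placeholder for one k-means run.

```java
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class BestOfRuns {
    /** Outcome of one clustering run (hypothetical type for this sketch). */
    static final class Result {
        final int run;
        final double distortion;

        Result(int run, double distortion) {
            this.run = run;
            this.distortion = distortion;
        }
    }

    /** Executes all tasks on a thread pool and returns the lowest-distortion result. */
    static Result bestOf(List<Callable<Result>> tasks) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(
                Math.max(1, Runtime.getRuntime().availableProcessors()));
        try {
            Result best = null;
            // invokeAll blocks until every task has finished.
            for (Future<Result> f : pool.invokeAll(tasks)) {
                Result r = f.get();
                if (best == null || r.distortion < best.distortion) {
                    best = r;
                }
            }
            return best;
        } finally {
            pool.shutdown();
        }
    }
}
```

Because k-means only converges to a local optimum that depends on the initial centroids, repeating it with different seeds and keeping the best run is a cheap way to improve the final distortion.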

Smile's kd-tree (BBDTree)

/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.clustering;

import java.util.Arrays;

import smile.math.Math;

/**
 * Balanced Box-Decomposition Tree. BBD tree is a specialized k-d tree that
 * vastly speeds up an iteration of k-means. This is used internally by KMeans
 * and batch SOM, and will most likely not need to be used directly.
 * <p>
 * The structure works as follows:
 * <ul>
 * <li> All data points are placed into a tree where we choose child nodes by
 * partitioning the data points along a plane parallel to the axis.
 * <li> We maintain for each node, the bounding box of all data points stored
 * at that node.
 * <li> To do a k-means iteration, we need to assign data to clusters and
 * calculate the sum and the number of data assigned to each cluster.
 * For each node in the tree, we can rule out some cluster centroids as
 * being too far away from every single point in that bounding box.
 * Once only one cluster is left, all data in the node can be assigned
 * to that cluster in batch.
 * </ul>
 * <p>
 * <h2>References</h2>
 * <ol>
 * <li>Tapas Kanungo, David M. Mount, Nathan S. Netanyahu, Christine D. Piatko, Ruth Silverman, and Angela Y. Wu. An Efficient k-Means Clustering Algorithm: Analysis and Implementation. IEEE TRANS. PAMI, 2002.</li>
 * </ol>
 *
 * @author Haifeng Li
 * @see KMeans
 * @see smile.vq.SOM
 */
public class BBDTree {

    class Node {
        /**
         * The number of data stored in this node.
         */
        int count;
        /**
         * The smallest point index stored in this node.
         */
        int index;
        /**
         * The center/mean of bounding box.
         */
        double[] center;
        /**
         * The half side-lengths of bounding box.
         */
        double[] radius;
        /**
         * The sum of the data stored in this node.
         */
        double[] sum;
        /**
         * The min cost for putting all data in this node in 1 cluster
         */
        double cost;
        /**
         * The child node of lower half box.
         */
        Node lower;
        /**
         * The child node of upper half box.
         */
        Node upper;

        /**
         * Constructor.
         *
         * @param d the dimension of vector space.
         */
        Node(int d) {
            center = new double[d];
            radius = new double[d];
            sum = new double[d];
        }
    }

    /**
     * Root node.
     */
    private Node root;
    /**
     * The index of data objects.
     */
    private int[] index;

    /**
     * Constructs a tree out of the given n data points living in R^d.
     */
    public BBDTree(double[][] data) {
        int n = data.length;

        index = new int[n];
        for (int i = 0; i < n; i++) {
            index[i] = i;
        }

        // Build the tree
        root = buildNode(data, 0, n);
    }

    /**
     * Build a k-d tree from the given set of data.
     */
    private Node buildNode(double[][] data, int begin, int end) {
        int d = data[0].length;

        // Allocate the node
        Node node = new Node(d);

        // Fill in basic info
        node.count = end - begin;
        node.index = begin;

        // Calculate the bounding box
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        // Initialize both bounds with the first point in the range
        for (int i = 0; i < d; i++) {
            lowerBound[i] = data[index[begin]][i];
            upperBound[i] = data[index[begin]][i];
        }
        // Find the per-dimension minimum and maximum over the remaining points
        for (int i = begin + 1; i < end; i++) {
            for (int j = 0; j < d; j++) {
                double c = data[index[i]][j];
                if (lowerBound[j] > c) {
                    lowerBound[j] = c;
                }
                if (upperBound[j] < c) {
                    upperBound[j] = c;
                }
            }
        }

        // Calculate bounding box stats
        double maxRadius = -1;
        int splitIndex = -1;
        for (int i = 0; i < d; i++) {
            node.center[i] = (lowerBound[i] + upperBound[i]) / 2;
            node.radius[i] = (upperBound[i] - lowerBound[i]) / 2;
            if (node.radius[i] > maxRadius) {
                maxRadius = node.radius[i];
                splitIndex = i;
            }
        }

        // If the max spread is 0, make this a leaf node
        if (maxRadius < 1E-10) {
            node.lower = node.upper = null;
            System.arraycopy(data[index[begin]], 0, node.sum, 0, d);

            if (end > begin + 1) {
                int len = end - begin;
                for (int i = 0; i < d; i++) {
                    node.sum[i] *= len;
                }
            }

            node.cost = 0;
            return node;
        }

        // Partition the data around the midpoint in this dimension. The
        // partitioning is done in-place by iterating from left-to-right and
        // right-to-left in the same way that partitioning is done in quicksort.
        double splitCutoff = node.center[splitIndex];
        int i1 = begin, i2 = end - 1, size = 0;
        while (i1 <= i2) {
            boolean i1Good = (data[index[i1]][splitIndex] < splitCutoff);
            boolean i2Good = (data[index[i2]][splitIndex] >= splitCutoff);

            if (!i1Good && !i2Good) {
                int temp = index[i1];
                index[i1] = index[i2];
                index[i2] = temp;
                i1Good = i2Good = true;
            }

            if (i1Good) {
                i1++;
                size++;
            }

            if (i2Good) {
                i2--;
            }
        }

        // Create the child nodes
        node.lower = buildNode(data, begin, begin + size);
        node.upper = buildNode(data, begin + size, end);

        // Calculate the new sum and opt cost
        for (int i = 0; i < d; i++) {
            node.sum[i] = node.lower.sum[i] + node.upper.sum[i];
        }

        double[] mean = new double[d];
        for (int i = 0; i < d; i++) {
            mean[i] = node.sum[i] / node.count;
        }

        node.cost = getNodeCost(node.lower, mean) + getNodeCost(node.upper, mean);
        return node;
    }

    /**
     * Returns the total contribution of all data in the given kd-tree node,
     * assuming they are all assigned to a mean at the given location.
     * <p>
     * sum_{x \in node} ||x - mean||^2.
     * <p>
     * If c denotes the mean of mass of the data in this node and n denotes
     * the number of data in it, then this quantity is given by
     * <p>
     * n * ||c - mean||^2 + sum_{x \in node} ||x - c||^2
     * <p>
     * The sum is precomputed for each node as cost. This formula follows
     * from expanding both sides as dot products.
     *
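     * In effect this is the sum over all dimensions of the squared deviations
     * from the given center (an unnormalized variance), which can be read as a
     * measure of how stable (compact) the node is.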
     */
    private double getNodeCost(Node node, double[] center) {
        int d = center.length;
        double scatter = 0.0;
        for (int i = 0; i < d; i++) {
            double x = (node.sum[i] / node.count) - center[i];
            scatter += x * x;
        }
        return node.cost + node.count * scatter;
    }

    /**
     * Given k cluster centroids, this method assigns data to nearest centroids.
     * The return value is the distortion to the centroids. The parameter sums
     * will hold the sum of data for each cluster. The parameter counts hold
     * the number of data of each cluster. If membership is
     * not null, it should be an array of size n that will be filled with the
     * index of the cluster [0 - k) that each data point is assigned to.
     */
    public double clustering(double[][] centroids, double[][] sums, int[] counts, int[] membership) {
        int k = centroids.length;

        Arrays.fill(counts, 0);
        int[] candidates = new int[k];
        for (int i = 0; i < k; i++) {
            candidates[i] = i;
            Arrays.fill(sums[i], 0.0);
        }

        return filter(root, centroids, candidates, k, sums, counts, membership);
    }

    /**
     * This determines which clusters all data that are rooted node will be
     * assigned to, and updates sums, counts and membership (if not null)
     * accordingly. Candidates maintains the set of cluster indices which
     * could possibly be the closest clusters for data in this subtree.
     */
    private double filter(Node node, double[][] centroids, int[] candidates, int k, double[][] sums, int[] counts, int[] membership) {
        int d = centroids[0].length;

        // Determine which mean the node mean is closest to
        double minDist = Math.squaredDistance(node.center, centroids[candidates[0]]);
        int closest = candidates[0];
        for (int i = 1; i < k; i++) {
            double dist = Math.squaredDistance(node.center, centroids[candidates[i]]);
            if (dist < minDist) {
                minDist = dist;
                closest = candidates[i];
            }
        }

        // If this is a non-leaf node, recurse if necessary
        if (node.lower != null) {
            // Build the new list of candidates
            int[] newCandidates = new int[k];
            int newk = 0;

            for (int i = 0; i < k; i++) {
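                // Keep a candidate if at least one extreme corner of the box
                // (center +/- radius) may be closer to it than to the closest
                // centroid; recursion continues until every point in a node
                // belongs to the closest centroid.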
                if (!prune(node.center, node.radius, centroids, closest, candidates[i])) {
                    newCandidates[newk++] = candidates[i];
                }
            }

            // Recurse if there's at least two
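            // prune() returns false when testIndex == bestIndex, so the closest
            // centroid is always kept; newk > 1 means another candidate survived.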
            if (newk > 1) {
                double result = filter(node.lower, centroids, newCandidates, newk, sums, counts, membership) + filter(node.upper, centroids, newCandidates, newk, sums, counts, membership);

                return result;
            }
        }

        // Assigns all data within this node to a single mean
        for (int i = 0; i < d; i++) {
            sums[closest][i] += node.sum[i];
        }

        counts[closest] += node.count;

        if (membership != null) {
            int last = node.index + node.count;
            for (int i = node.index; i < last; i++) {
                membership[index[i]] = closest;
            }
        }

        return getNodeCost(node, centroids[closest]);
    }

    /**
     * Checks whether some point of the current kd-tree node's box could still
     * belong to a cluster other than that of the closest centroid.
     * <p>
     * Determines whether every point in the box is closer to centroids[bestIndex] than to
     * centroids[testIndex].
     * <p>
     * If x is a point, c_0 = centroids[bestIndex], c = centroids[testIndex], then:
     * (x-c).(x-c) < (x-c_0).(x-c_0)
     * <=> (c-c_0).(c-c_0) < 2(x-c_0).(c-c_0)
     * <p>
     * The right-hand side is maximized for a vertex of the box where for each
     * dimension, we choose the low or high value based on the sign of c-c_0 in
     * that dimension. If even that vertex fails the inequality, no point in the
     * box is closer to c than to c_0, and c can be pruned.
     */
    private boolean prune(double[] center, double[] radius, double[][] centroids, int bestIndex, int testIndex) {
        if (bestIndex == testIndex) {
            return false;
        }

        int d = centroids[0].length;

        double[] best = centroids[bestIndex];
        double[] test = centroids[testIndex];
        double lhs = 0.0, rhs = 0.0;
        for (int i = 0; i < d; i++) {
            double diff = test[i] - best[i];
            lhs += diff * diff;
            if (diff > 0) {
                rhs += (center[i] + radius[i] - best[i]) * diff;
            } else {
                rhs += (center[i] - radius[i] - best[i]) * diff;
            }
        }

        return (lhs >= 2 * rhs);
    }
}
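`getNodeCost` relies on the identity sum ||x - mean||^2 = n * ||c - mean||^2 + sum ||x - c||^2, where c is the mean of the node's points. A small numeric check can make that concrete. This is a standalone sketch with hypothetical names (`CostIdentity`, `direct`, `viaCentroid`), not part of Smile.

```java
class CostIdentity {
    /** Direct evaluation: sum over all points of ||x - m||^2. */
    static double direct(double[][] pts, double[] m) {
        double total = 0.0;
        for (double[] x : pts) {
            for (int j = 0; j < m.length; j++) {
                double t = x[j] - m[j];
                total += t * t;
            }
        }
        return total;
    }

    /** Same quantity via the centroid c: n * ||c - m||^2 + sum ||x - c||^2. */
    static double viaCentroid(double[][] pts, double[] m) {
        int n = pts.length;
        int d = m.length;
        double[] c = new double[d];
        for (double[] x : pts) {
            for (int j = 0; j < d; j++) {
                c[j] += x[j] / n;
            }
        }
        double shift = 0.0;
        for (int j = 0; j < d; j++) {
            shift += (c[j] - m[j]) * (c[j] - m[j]);
        }
        // direct(pts, c) plays the role of the precomputed node.cost.
        return n * shift + direct(pts, c);
    }
}
```

For the four corners of the square [0,2] x [0,2] and m = (3, 1), both methods give 24: the centroid is (1, 1), so n * ||c - m||^2 = 4 * 4 = 16 and the cost around the centroid is 4 * 2 = 8. The benefit in the tree is that the second form needs only the node's stored `sum`, `count`, and `cost`, never the individual points.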

Iterative computation

KMeans(BBDTree bbd, double[][] data, int k, int maxIter) {
    if (k < 2) {
        throw new IllegalArgumentException("Invalid number of clusters: " + k);
    }

    if (maxIter <= 0) {
        throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
    }

    int n = data.length;
    int d = data[0].length;

    this.k = k;
    distortion = Double.MAX_VALUE;
    y = seed(data, k, ClusteringDistance.EUCLIDEAN);
    size = new int[k];
    centroids = new double[k][d];
    // Count how many points the seeding assigned to each cluster
    for (int i = 0; i < n; i++) {
        size[y[i]]++;
    }
    // Accumulate the sum of the data assigned to each cluster
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < d; j++) {
            centroids[y[i]][j] += data[i][j];
        }
    }
    // Compute each centroid as sum / count
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < d; j++) {
            centroids[i][j] /= size[i];
        }
    }

    double[][] sums = new double[k][d];
    for (int iter = 1; iter <= maxIter; iter++) {
        double dist = bbd.clustering(centroids, sums, size, y);
        for (int i = 0; i < k; i++) {
            if (size[i] > 0) {
                for (int j = 0; j < d; j++) {
                    centroids[i][j] = sums[i][j] / size[i];
                }
            }
        }

        if (distortion <= dist) {
            break;
        } else {
            distortion = dist;
        }
    }
}
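The constructor above starts from `seed(data, k, ClusteringDistance.EUCLIDEAN)`, whose body is not shown in this post. Since the initialization is described as k-means++, the following is a sketch of standard k-means++ seeding (sampling proportional to squared distance), not Smile's actual `seed` method. The class `KMeansPlusPlusSketch` and its `seed` signature are assumptions; in particular it returns the indices of the chosen points, whereas the call above evidently returns a cluster label per point (stored in `y`).

```java
import java.util.Arrays;
import java.util.Random;

class KMeansPlusPlusSketch {
    static double squaredDistance(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return s;
    }

    /** Returns the indices of k data points chosen as initial centers. */
    static int[] seed(double[][] data, int k, Random rnd) {
        int n = data.length;
        int[] centers = new int[k];
        double[] d2 = new double[n]; // squared distance to the nearest chosen center
        Arrays.fill(d2, Double.MAX_VALUE);
        centers[0] = rnd.nextInt(n); // first center: uniformly at random
        for (int c = 1; c < k; c++) {
            double total = 0.0;
            for (int i = 0; i < n; i++) {
                d2[i] = Math.min(d2[i], squaredDistance(data[i], data[centers[c - 1]]));
                total += d2[i];
            }
            // Sample index i with probability d2[i] / total.
            double r = rnd.nextDouble() * total;
            double acc = 0.0;
            int pick = -1;
            for (int i = 0; i < n; i++) {
                if (d2[i] > 0) {
                    pick = i; // remember the last point that can still be chosen
                    acc += d2[i];
                    if (acc > r) {
                        break;
                    }
                }
            }
            // Fall back to a uniform pick only if every point coincides with a center.
            centers[c] = pick >= 0 ? pick : rnd.nextInt(n);
        }
        return centers;
    }
}
```

Points that already coincide with a chosen center have weight zero, so they are never picked again while other candidates remain; this tends to spread the initial centroids across the data.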
