查找决策树中到决策边界的距离

2023-12-14

我想找到样本到经过训练的决策树分类器的决策边界的距离scikit学习。特征都是数字的,特征空间可以是任何大小。

到目前为止,我有一个基于示例 2D 案例的可视化here:

import numpy as np
import matplotlib.pyplot as plt

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_moons

# Generate some example data
X, y = make_moons(noise=0.3, random_state=0)

# Train the classifier
clf = DecisionTreeClassifier(max_depth=2)

clf.fit(X, y)

# Plot
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor='k')
plt.xlabel('a'); plt.ylabel('b');

enter image description here

据我所知,对于像 SVM 这样的其他分类器,这个距离可以通过数学计算 [1, 2, 3]。训练决策树后学到的规则定义了边界,也可能有助于通过算法计算距离[4, 5, 6]:

# Plot the trained tree
from sklearn import tree
import graphviz 
dot_data = tree.export_graphviz(clf, feature_names=['a', 'b'],  class_names=['1', '2'], filled=True)  
graph = graphviz.Source(dot_data)  

enter image description here


由于样本周围可能有多个决策边界,因此我假设这里的距离是指到最近决策边界的距离。

解决方案是递归树遍历算法。请注意,决策树不允许样本位于边界上,例如SVM,特征空间中的每个样本必须属于其中一个类。因此,在这里,我们将继续以小步骤修改样本的特征,每当这导致一个具有不同标签的区域(与最初由训练有素的分类器分配给样本的标签相比)时,我们就假设我们已经达到了决策边界。

详细来说,就像任何递归算法一样,我们有两种主要情况需要考虑:

  1. 基本情况,即我们位于叶节点。我们简单地检查当前样本是否具有不同的标签:如果是则返回它,否则返回None.
  2. 非叶节点。有两个分支机构,我们将样品发送给两个分支机构。我们不会修改样本以将其发送到它自然会采用的分支。但在将其发送到另一个分支之前,我们查看节点的(特征,阈值)对,并修改样本的给定特征,使其足以将其推到阈值的另一侧。

完整的Python代码:

def f(node,x,orig_label):
    global dt,tree
    if tree.children_left[node]==tree.children_right[node]: #Meaning node is a leaf
        return [x] if dt.predict([x])[0]!=orig_label else [None]

    if x[tree.feature[node]]<=tree.threshold[node]:
        orig = f(tree.children_left[node],x,orig_label)
        xc = x.copy()
        xc[tree.feature[node]] = tree.threshold[node] + .01
        modif = f(tree.children_right[node],xc,orig_label)
    else:
        orig = f(tree.children_right[node],x,orig_label)
        xc = x.copy()
        xc[tree.feature[node]] = tree.threshold[node] 
        modif = f(tree.children_left[node],xc,orig_label)
    return [s for s in orig+modif if s is not None]

这将返回给我们一个样本列表,这些样本会导致具有不同标签的叶子。我们现在需要做的就是取最近的一个:

dt =  DecisionTreeClassifier(max_depth=2).fit(X,y)
tree = dt.tree_
res = f(0,x,dt.predict([x])[0]) # 0 is index of root node
ans = np.min([np.linalg.norm(x-n) for n in res]) 

举例说明:

enter image description here

蓝色是原始样本,黄色是“在”决策边界上最近的样本。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

查找决策树中到决策边界的距离 的相关文章

随机推荐

  • 使用闭包来跟踪变量:好主意还是肮脏的伎俩?

    好的 我需要能够跟踪值类型对象 这些对象是另一个对象上的属性 如果这些属性不实现 IObservable 接口或类似接口 这是无法完成的 然后我想到了闭包和 Jon Skeet 的著名例子 以及如何多次打印 9 或 10 而不是按升序排列的
  • 如何使用 REST Api 从 salesforce 中的字段获取所有选项列表值?

    我正在尝试使用 REST API 从 salesforce 中的字段获取所有选项列表值 可以这样做吗 如果是的话那该怎么办呢 Thanks Raj 这很简单 您需要访问与此类似的资源 services data v26 0 sobjects
  • ANSI C:__DATE__ 和 __TIME__ 字符串大小的标准定义?

    ANSI C 中 DATE 和 TIME 字符串的大小是否有标准定义 这个问题背后的动机是 我有两个应用程序在两个不同的 CPU 上运行 在运行时 应用程序 1 从应用程序 2 接收日期和时间 作为版本信息的一部分 当然 应用程序 2 从预
  • Javascript 类中的方法链接[重复]

    这个问题在这里已经有答案了 我正在尝试在我的子类中实现方法链接 定位球 class Ball constructor name size power this name name this size size this power powe
  • 匿名方法 - 3 种不同的方式 - 异步

    不确定在标题中写什么 它们可能并不都是匿名方法 但这里是 假设我们有这个异步函数 public async Task Delete something 我正在使用 Blazor 服务器端 我对以下四种调用函数的方式感到好奇 假设它们位于 d
  • Android In App BIlling v3 - 错误的订阅试用期

    我正在使用 Android In App BIlling v3 库 当我调用 bp subscribe Activity subscriptionID 我获得了 Google Play 购买窗口 但计费周期始终为 每天 试用期始终为 1 天
  • Symfony2 Twig 无限子深度

    我有一个自连接表 其中每个文件夹都有一个父文件夹 并且其深度是无限的 一个文件夹可以有另一个文件夹作为父文件夹 没有深度限制 今天我的代码看起来像这样 我正在寻找一种根据需要深入挖掘的方法 而无需对每个步骤进行硬编码 是否有一种方法可以用循
  • Cpdf.php 第 3855 行中的 ErrorException:未定义索引:位于 barryvdh/laravel-dompdf

    我正在使用 laravel 5 2 dompdf 在本地主机上运行良好 但当移动到 AWS 时 它不断显示ErrorException in Cpdf php line 3855 Undefined index 在这一行 3855 中有字体
  • 带有 ssl 本地证书的 QNetworkRequest

    我需要与需要本地证书 crt 文件 的服务器交换数据 我试试这个 loginRequest QNetworkRequest QUrl https somesite com login QSslConfiguration sslConf lo
  • 打印特定类型的金字塔

    对于uni 我们必须打印特定类型的金字塔 这是代码 h 10 def build string pyramid s for i in range 1 h 1 print 1 end for j in range 2 i 1 print en
  • 致命错误:找不到类“Swift_smtpTransport”

    我正在尝试添加从我的网站后端向客户发送电子邮件的功能 并尝试使用 swiftmailer 来执行此操作 不幸的是 我不断收到错误消息 Fatal error Class Swift smtpTransport not found in ho
  • 使用翻译行为时如何查询翻译的内容?

    我的网站有多种语言 因此文章的标题取决于当地语言 但有一个问题 如何搜索另一种语言的文章 目前 唯一的方法是输入英文标题 以便 cakePHP 检索法文名称 我无法用法语搜索它 例如 当我搜索 Hello 时 我找到了名为 Bonjour
  • 如何从 Windows 剪贴板读取位图

    我正在编写一个非常小的 C 程序来帮助我制作精灵动画 我希望它能够获取从 Photoshop 复制到剪贴板的数据 在我的程序中对其进行操作 然后使用转换覆盖剪贴板 但问题是我不知道如何从 Photoshop 读取初始剪贴板 我可以加载剪贴板
  • 如何以编程方式或定期清除操作 PrintService 事件日志?

    我们正在尝试对在 Windows Server 2008 R2 上运行的打印机进行一些内部打印审核 通过事件查看器启用日志后 应用程序和服务日志 gt Microsoft gt Windows gt PrintService gt 操作 我
  • 如何从 weka API 计算置信度?

    我正在使用weka java API 在训练集上训练后我可以得到预测的类标签 双 pred fc classifyInstance test instance i 但我想知道类标签的置信概率 我应该使用什么函数 在 GUI 中 我可以选择将
  • 如何使 webpack 开发服务器在端口 80 和 0.0.0.0 上运行以使其可公开访问?

    我对整体是新的nodejs reactjs如果我的问题听起来很愚蠢 世界深表歉意 我目前正在玩反应性 js 每当我做一个npm start它总是继续运行localhost 8080 我如何将其更改为运行0 0 0 0 8080使其公开 我一
  • 检查线程是否是boost线程

    为了进行线程本地清理 我需要创建一个断言来检查当前线程是否是通过 boost thread 创建的 我怎样才能检查是否是这种情况 也就是说 如何检查当前线程是否由 boost thread 处理 我只需要在线程退出时清理线程本地存储 Boo
  • 获取消息:来自 AWS API 网关的禁止回复

    我正在尝试在 AWS 上创建 lambda 服务 并通过 API 网关从外部访问它 无需身份验证或限制 为了让事情变得简单 我现在将网关设置为模拟 在 API 的 Get 方法中 授权设置为NoneAPI 密钥是not required 当
  • 如何在插入工作时更新 BLOB 列,错误 ORA-00932

    我无法更新 BLOB 字段 但插入可以 请参阅下面的代码 我的猜测是 这与在大量记录中存储一个 BLOB 值的问题有关 涉及复制大数据 就我而言 我知道只会更新一条记录 但 Oracle 可能认为可能需要更新多条记录 使用插入时 可以保证只
  • 查找决策树中到决策边界的距离

    我想找到样本到经过训练的决策树分类器的决策边界的距离scikit学习 特征都是数字的 特征空间可以是任何大小 到目前为止 我有一个基于示例 2D 案例的可视化here import numpy as np import matplotlib