从 GradientBoostingClassifier 中提取决策规则

2023-12-11

我已经解决了以下问题:

如何提取梯度提升分类器的决策规则

如何从 scikit-learn 决策树中提取决策规则?

然而以上两个并没有解决我的目的。以下是我的查询:

我需要使用gradientboostingclassifer在Python中构建一个模型,并在SAS平台中实现该模型。为此,我需要从gradientboostingclassifer中提取决策规则。

以下是我迄今为止尝试过的:

在 IRIS 数据上构建模型:

# import the most common dataset
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO  
from IPython.display import Image

X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape) # (150, 4)
# let's build a small model = 5 trees with depth no more than 2
model = GradientBoostingClassifier(n_estimators=5, max_depth=3, learning_rate=1.0)
model.fit(X, y==2) # predict 2nd class vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()

def plot_tree(clf):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, node_ids=True,
                    filled=True, rounded=True, 
                    special_characters=True)
    graph = pydotplus.graph_from_dot_data([enter image description here][3]dot_data.getvalue())  
    return Image(graph.create_png())

# now we can plot the first tree
plot_tree(trees[0])

绘制图表后,我检查了第一棵树的图表源代码并使用以下代码写入文本文件:

with open("C:\\Users\XXXX\Desktop\Python\input_tree.txt", "w") as wrt:
    wrt.write(export_graphviz(trees[0], out_file=None, node_ids=True,
                filled=True, rounded=True, 
                special_characters=True))

下面是输出文件:

digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;
edge [fontname=helvetica] ;
0 [label=<node &#35;0<br/>X<SUB>3</SUB> &le; 1.75<br/>friedman_mse = 0.222<br/>samples = 150<br/>value = 0.0>, fillcolor="#e5813955"] ;
1 [label=<node &#35;1<br/>X<SUB>2</SUB> &le; 4.95<br/>friedman_mse = 0.046<br/>samples = 104<br/>value = -0.285>, fillcolor="#e5813945"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label=<node &#35;2<br/>X<SUB>3</SUB> &le; 1.65<br/>friedman_mse = 0.01<br/>samples = 98<br/>value = -0.323>, fillcolor="#e5813943"] ;
1 -> 2 ;
3 [label=<node &#35;3<br/>friedman_mse = 0.0<br/>samples = 97<br/>value = -1.5>, fillcolor="#e5813900"] ;
2 -> 3 ;
4 [label=<node &#35;4<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = 3.0>, fillcolor="#e58139ff"] ;
2 -> 4 ;
5 [label=<node &#35;5<br/>X<SUB>3</SUB> &le; 1.55<br/>friedman_mse = 0.222<br/>samples = 6<br/>value = 0.333>, fillcolor="#e5813968"] ;
1 -> 5 ;
6 [label=<node &#35;6<br/>friedman_mse = 0.0<br/>samples = 3<br/>value = 3.0>, fillcolor="#e58139ff"] ;
5 -> 6 ;
7 [label=<node &#35;7<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.0>, fillcolor="#e5813955"] ;
5 -> 7 ;
8 [label=<node &#35;8<br/>X<SUB>2</SUB> &le; 4.85<br/>friedman_mse = 0.021<br/>samples = 46<br/>value = 0.645>, fillcolor="#e581397a"] ;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
9 [label=<node &#35;9<br/>X<SUB>1</SUB> &le; 3.1<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.333>, fillcolor="#e5813968"] ;
8 -> 9 ;
10 [label=<node &#35;10<br/>friedman_mse = 0.0<br/>samples = 2<br/>value = 3.0>, fillcolor="#e58139ff"] ;
9 -> 10 ;
11 [label=<node &#35;11<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = -1.5>, fillcolor="#e5813900"] ;
9 -> 11 ;
12 [label=<node &#35;12<br/>friedman_mse = -0.0<br/>samples = 43<br/>value = 3.0>, fillcolor="#e58139ff"] ;
8 -> 12 ;
}

为了从输出文件中提取决策规则,我尝试将以下 python RegEX 代码转换为 SAS 代码:

 import re
with open("C:\\Users\XXXX\Desktop\Python\input_tree.txt") as f:
    with open("C:\\Users\XXXX\Desktop\Python\output.txt", "w") as f1:
        result0 = 'value = 0;'
        f1.write(result0)
        for line in f:
            result1 = re.sub(r'^(\d+)\s+.*<br\/>([A-Z]+)<SUB>(\d+)<\/SUB>\s+(.+?)([-\d.]+)<br\/>friedman_mse.*;$',r"if \2\3 \4 \5 then do;",line)
            result2 = re.sub(r'^(\d+).*(?!SUB).*(value\s+=)\s([-\d.]+).*;$',r"\2 value + \3; end;",result1)
            result3 = re.sub(r'^(\d+\s+->\s+\d+\s+);$',r'\1',result2)
            result4 = re.sub(r'^digraph.+|^node.+|^edge.+','',result3)
            result5 = re.sub(r'&(\w{2});',r'\1',result4)
            result6 = re.sub(r'}','end;',result5)
            f1.write(result6)

以下是上述代码的输出 SAS:

value = 0;
if X3 le  1.75 then do;
if X2 le  4.95 then do;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
if X3 le  1.65 then do;
1 -> 2 
value = value + -1.5; end;
2 -> 3 
value = value + 3.0; end;
2 -> 4 
if X3 le  1.55 then do;
1 -> 5 
value = value + 3.0; end;
5 -> 6 
value = value + 0.0; end;
5 -> 7 
if X2 le  4.85 then do;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
if X1 le  3.1 then do;
8 -> 9 
value = value + 3.0; end;
9 -> 10 
value = value + -1.5; end;
9 -> 11 
value = value + 3.0; end;
8 -> 12 
end;

正如您所看到的,输出文件中缺少一块,即我无法正确打开/关闭 do-end 块。为此,我需要使用节点号,但我失败了,因为我无法在这里找到任何模式。

你们中的任何人都可以帮我解决这个问题吗?

除此之外,像决策树分类器一样,我可以不提取上面第二个链接中提到的children_left、children_right、阈值吗?我已经成功提取了GBM的每棵树

trees = model.estimators_.ravel()

但我没有找到任何有用的函数可以用来提取每棵树的值和规则。如果我可以以类似于 DecisionTreeclassifier 的方式使用 grapviz 对象,请提供帮助。

OR

帮助我使用任何其他可以解决我的目的的方法。


无需使用 graphviz 导出来访问决策树数据。model.estimators_包含模型组成的所有单独分类器。对于 GradientBoostingClassifier,这是一个形状为 (n_estimators, n_classes) 的 2D numpy 数组,每一项都是 DecisionTreeRegressor。

每个决策树都有一个属性_tree and 了解决策树结构展示了如何从该对象中获取节点、阈值和子对象。


import numpy
import pandas
from sklearn.ensemble import GradientBoostingClassifier

est = GradientBoostingClassifier(n_estimators=4)
numpy.random.seed(1)
est.fit(numpy.random.random((100, 3)), numpy.random.choice([0, 1, 2], size=(100,)))
print('s', est.estimators_.shape)

n_classes, n_estimators = est.estimators_.shape
for c in range(n_classes):
    for t in range(n_estimators):
        dtree = est.estimators_[c, t]
        print("class={}, tree={}: {}".format(c, t, dtree.tree_))

        rules = pandas.DataFrame({
            'child_left': dtree.tree_.children_left,
            'child_right': dtree.tree_.children_right,
            'feature': dtree.tree_.feature,
            'threshold': dtree.tree_.threshold,
        })
        print(rules)

为每棵树输出类似这样的内容:

class=0, tree=0: <sklearn.tree._tree.Tree object at 0x7f18a697f370>
   child_left  child_right  feature  threshold
0           1            2        0   0.020702
1          -1           -1       -2  -2.000000
2           3            6        1   0.879058
3           4            5        1   0.543716
4          -1           -1       -2  -2.000000
5          -1           -1       -2  -2.000000
6           7            8        0   0.292586
7          -1           -1       -2  -2.000000
8          -1           -1       -2  -2.000000
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从 GradientBoostingClassifier 中提取决策规则 的相关文章

随机推荐

  • 在python中求一棵树的最大和

    我有一棵数字树 我希望能够找到数字之和 每个数字下面是左右两个孩子 在所有可能的路径中 我希望能够通过所有可能的路径找到最大的数字 这是一个例子 8 3 11 10 2 32 6 返回 8 11 32 51 我觉得这是一个递归问题 但我坚持
  • 快速卷积算法

    我需要对两个一维信号进行卷积 一个平均有 500 个点 这个是汉宁窗函数 另一个有 125000 个点 每次运行 我需要应用三倍的卷积运算 我已经有一个基于 scipy 文档运行的实现 如果您愿意 您可以在此处查看代码 前面是 Delphi
  • Pentaho数据集成Kettle转换中如何配置生产环境的数据库连接

    我设计了一个ktr文件进行转换 我需要配置生产环境的数据库连接详细信息 我怎样才能做到这一点 有什么建议么 我使用环境变量 KETTLE HOME KETTLE JNDI ROOT PATH PATH KETTLE HOME Kettle
  • Boost Signals2 自动连接管理和更改信号的互斥类型

    我正在尝试使用自动连接管理 and 更改信号的互斥类型对于模板函数 以下代码使用 gcc 4 3 4 可以正常编译和执行 http ideone com LLN6d include
  • Renci SSH.NET - 没有为 opmnctl 返回结果字符串

    我试图通过 VB NET 从命令获取结果 它返回空字符串 如下面的代码所示 Dim connInfo As New Renci SshNet PasswordConnectionInfo serverip user pass Dim ssh
  • IonRangeSlider 将标签分配给值

    我在用离子范围滑块我想为值分配标签 反之亦然 因此 用户可以通过以下选项选择距海滩的距离 on beach 100m 200m 300m more than 300m 但我需要post价值观像 0 100 200 300 999 My in
  • 通过一次导入 csv 将多个用户添加到多个组(后续查询)

    我一直在寻找一种使用多个用户名填充多个通讯组的方法 我碰到本网站上的脚本由成员 Frode F 编写 Import Csv C Scripts Import Bulk Users into bulk groups bulkgroups3 c
  • 如果没有明确设置一个巨大的常量值,是否可以期望被告知不要超时?

    我已将超时设置为一个愚蠢的高数字 有没有更好的方法告诉脚本不要超时 usr bin expect spawn telnet 10 10 10 10 set timeout 200000000 expect login send user r
  • iOS 15 中 UIButton 图像行为发生变化?

    我的代码很简单 我有一个 UIButton 的出口 button 我在代码中设置它的图像 let jack UIImage named jack png self button setImage jack for normal 问题是这并不
  • Cython C++ 包装器运算符() 重载错误

    与我之前的问题有关 使用 Cython 包装使用 OpenCV 类型作为参数的 C 类 现在我陷入了另一个错误 我的 OpenCV 类型 Matx33d 的 cython 包装代码如下所示 cdef extern from opencv2
  • Eclipse:运行时我们如何获取 main 参数

    在 Java 中 对于普通的 main 方法 public static void main String args code here String args用于从命令行获取一些参数 我可以通过以下方式从命令提示符运行此文件 javac
  • 为什么修改迭代变量不影响后续迭代?

    这是我遇到问题的 Python 代码 for i in range 0 10 if i 5 i 3 print i 我预计输出是 0 1 2 3 4 8 9 然而 翻译却吐槽道 0 1 2 3 4 8 6 7 8 9 我知道一个for循环在
  • 在 matlab 等高线图中选择特定水平

    我有这个plot我生成它是为了测试等值线图在 matlab 上的工作原理 我想弄清楚是否有一种方法可以只绘制其中一条线 但不一定是第一条线 Matlab 的解释是 如果你这样做 contour X Y Z 1 它会绘制其中一条线 但它始终是
  • Oracle SQL 对版本号进行排序

    在 Oracle 中 只需使用ORDER BY不对版本号进行排序 我的Version Number字段被声明为VARCHAR我无法改变它 例如 以下版本 1 20 1 9 1 18 1 13 1 5 1 11 2 0 1 8 1 3 1 2
  • 使用SFTP / RCurl创建远程目录

    是否可以使用 RCurl 包在 SFTP 站点上创建目录 我找到了sftp create dirs函数 但我找不到如何使用它的示例 我尝试设置ftp create missing dirs选项TRUE as in library RCurl
  • JavaScript 提升函数与函数变量

    这是我的 JavaScript 代码 console log a c b var a Hello World var b function console log B is called function c console log C i
  • AttributeError:构建逻辑回归模型时“str”对象没有属性“decode”[重复]

    这个问题在这里已经有答案了 我正在尝试建立一个逻辑回归模型 但它显示了AttributeError str object has no attribute decode 请帮我解决这个问题 该代码在 Datacamp 的服务器上完美运行 但
  • ValueError:解析日期时时间数据与格式不匹配

    当我尝试将字符串解析为日期时间时 我这样做 之前已导入日期时间 fecha 2 datetime strptime 22 01 2019 17 00 d m y H M 但是 我收到此错误 ValueError 时间数据 22 01 201
  • PHP 方法链接的好处?

    仍在 PHP OOP 训练轮上 这个问题可能属于失败博客网站 PHP 中的方法链有什么好处 我不确定这是否重要 但我将静态调用我的方法 例如 foo Bar get sysop gt set admin gt render 根据我的阅读 任
  • 从 GradientBoostingClassifier 中提取决策规则

    我已经解决了以下问题 如何提取梯度提升分类器的决策规则 如何从 scikit learn 决策树中提取决策规则 然而以上两个并没有解决我的目的 以下是我的查询 我需要使用gradientboostingclassifer在Python中构建