决策树

2023-10-26

这篇博客用来简要介绍决策树算法(DecisionTree)

决策树是机器学习中常用的一种算法,它即可用于解决分类问题,也可用于解决回归问题,在这篇博客我们只介绍分类决策树。

决策树顾名思义是一种树形结构,而我们的任务就是想办法构建出这样一颗树用它来进行分类。

在开始介绍决策树的构建之前,首先介绍几个相关概念,信息熵条件熵以及信息增益

信息熵

我个人的理解,信息熵就是用来衡量一个随机变量取值的不确定性的一个指标,信息熵越大则不确定性越大,信息熵越小则不确定性也就越小。

假设一个随机变量X的概率分布如下:

X

v_1

v_2 .... v_n
p p_1 p_2 .... p_n

则随机变量X的信息熵的计算公式如下:

H(X)=\sum _{i=1}^{n}-p_ilogp_i

通常情况下对数以2为底或以e(自然对数)为底,并且我们规定如果p_i=0则定义0log0=0

一个服从两点分布的随机变量的信息熵图像如下图所示:

可以看出当概率p=0.5时,信息熵最大,随机变量的不确定性也就最大,从直观的角度这也很容易理解,即如果随机变量属于两个值的概率都相等,那我们很难确定它属于哪一个值,反之如果不相等,则我们将猜测随机变量属于概率大的哪一个值。

条件熵

设二维随机变量(X, Y)的概率分布为:

p(X=x_i, Y=y_i)=p_{ij}, \;\;i=1, 2, 3, ..., n;\;\;j=1, 2, 3, ..., m

p_i=p(X=x_i)

我们将随机变量X给定条件下随机变量Y的条件熵表示为H(Y|X),表示在已知随机变量X的情况下随机变量Y的不确定性,其定义如下:

H(Y|X)=\sum _{i=1}^{n}p_iH(Y|X=x_i)

实际上就是给定X的条件下随机变量Y的熵对于X的期望值。

信息增益

信息增益表示为g(Y, X),其含义为给定X的能够使随机变量Y的确定性增加的程度。

计算方式如下:

g(Y, X) = H(Y) - H(Y|X)

这个很好理解,当我们知道某些信息时肯定是要比什么都不知道的时候更能确定某个随机变量的取值,除非这些信息是无效信息,否则肯定会带来确定性的增加,即信息增益。

接下来我们就来探究决策树构建过程

决策树由节点和有向边所构成,节点分为叶子节点和内部节点两部分,通常情况下,节点分支都为二叉的,如下图所示:

图中圆形为内部节点,方块为叶子节点。

我们以案例驱动的方式来解释决策树是如何构建出来的,其过程分为两步,特征选择以及决策树的生成

首先我们给定一张数据表,数据表中记录的是一些贷款信息,如下图所示:

贷款申请样本数据表
ID 年龄 有工作 有房子 信贷情况 类别
1 青年 一般
2 青年
3 青年
4 青年 一般
5 青年 一般
6 中年 一般
7 中年
8 中年
9 中年 非常好
10 中年 非常好
11 老年 非常好
12 老年
13 老年
14 老年 非常好
15 老年 一般

我们的任务就是构建一颗决策树来进行判断是否同意某个人的贷款申请。

我们用随机变量Y来表示类别(0表示否,1表示是), A表示年龄(0表示青年,1表示中年,2表示老年),W是否有工作(0表示否,1表示是),R表示是否有房子(0表示否,1表示是),C表示信贷情况(0表示一般,1表示好,2表示非常好)

根节点包含的样本:所有样本

根节点信息熵:

H(Y) =- \frac{6}{15}log\frac{6}{15} - \frac{9}{15}log\frac{9}{15}\approx 0.97

根节点各个特征的条件熵:

H(Y|A)=p(A=0)H(Y|A=0)+p(A=1)H(Y|A=1)+p(A=2)H(Y|A=2)=\frac{1}{3}(-\frac{2}{5}log\frac{2}{5}-\frac{3}{5}log\frac{3}{5})+\frac{1}{3}(-\frac{2}{5}log\frac{2}{5}-\frac{3}{5}log\frac{3}{5})+\frac{1}{3}(-\frac{1}{5}log\frac{1}{5}-\frac{4}{5}log\frac{4}{5})\approx 0.89

H(Y|W)=p(W=0)H(Y|W=0)+p(W=1)H(Y|W=1)=\frac{10}{15}(-\frac{6}{10}log\frac{6}{10}-\frac{4}{10}log\frac{4}{10})+\frac{5}{15}(-1log1 - 0log0)\approx 0.65

H(Y|R)=p(R=0)H(Y|R=0)+p(R=1)H(Y|R=1)=\frac{9}{15}(-\frac{6}{9}log\frac{6}{9}-\frac{3}{9}log\frac{3}{9})+\frac{6}{15}*0\approx 0.55

H(Y|C)=p(C=0)H(Y|C=0)+p(C=1)H(Y|C=1)+p(C=2)H(Y|C=2)\approx 0.61

计算信息增益:

g(Y, A)=H(Y) - H(Y|A)=0.08

g(Y, W)=H(Y) - H(Y|W)=0.32

g(Y, R)=H(Y) - H(Y|R)=0.42

g(Y, C)=H(Y) - H(Y|C)=0.36

因此可以确定,当分支特征为R(是否有房子)时带来的信息增益越大,因此根节点的分支特征选择为是否有房子,分为左右两支,左子节点中的数据集为没有房子的样本,右子节点的数据集为有房子,之后以此类推便构建出了决策树模型。

综上,每个节点的分支过程可以分为以下几步:

1.确定当前节点的信息熵以及各个特征的条件熵

2.计算各个特征的信息增益

3.确定当前节点的分支特征

那么,我们如何决定什么时候停止分支呢?我们可以这样设置,当某个节点的信息熵小于某个阈值时我们就停止对这个节点的分支操作,那么此节点也就成为了叶子节点。

最终我们需要确定每个叶子结点的类别,即叶子结点中的样本集中,占比最大的那一个类别便是当前叶子节点的类别,当新来一个样本我们只需要按照决策树从顶层向下逐步判断,看样本最终落入那个叶子结点,所落入的叶子结点的类别便是当前样本的预测类别。

至此决策树的基本原理已经介绍完毕,要说明的是,以上介绍的只是决策树当中的ID3算法,在ID3算法提出之后,相继又提出了C4.5算法和CART,与ID3算法不同的是,C4.5算法采用的特征选择指标是信息增益率而不是信息增益,其他原理同ID3相同。实际上信息增益率的计算方法也相当简单,就是信息增益与特征的熵的比值;而CART(Class and Regression Tree)从名字看出来即可做分类又可做回归,与ID3不同的是,作为分类树时,其选择划分节点的特征的依据是基尼系数,选择使得划分后数据集基尼系数最小的那个特征作为节点的划分的特征。如果是构建回归树,则划分节点的特征的选择是依据划分后子节点的方差,我们尽可能选择使得方差最小的划分特征,我们往往使用启发式的方式来搜索划分特征以及划分特征的值,即依次遍历各个维度的特征,当遍历到当前维度特征时,依次遍历当前特征的所有值作为划分值从而选出使得子节点方差最小的划分特征以及特征的值。

以下为本人使用Python编写的决策树的相关程序:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score


class DecesionTree(object):

    def __init__(self, stop_entropy, deepth):
        """

        :param stop_entropy: 终止分裂的最小熵值
        :param deepth: 决策树最大深度
        """
        self.stop_entropy = stop_entropy
        self.deepth = deepth
        data = load_iris()
        x = data["data"]
        y = data["target"]
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(x, y, random_state=2, test_size=0.3)
        self.tree = {}
        for layer in range(deepth):
            self.tree[layer] = [None] * 2 ** layer
        attr_max_values = np.max(self.x_train, axis=0)
        attr_min_values = np.min(self.x_train, axis=0)
        # self.split_values用来记录每一维特征的所有可能的分割值, self.split[i]表示第i维特征的所有可能的分割值
        self.split_values = []
        for min_value, max_value in zip(attr_min_values, attr_max_values):
            self.split_values.append(np.linspace(min_value, max_value, 102, endpoint=True)[1:-1])

    def generate_tree(self):
        for layer in range(self.deepth):
            # 遍历是第几层节点
            # node_num用于标记当前层节点的个数
            node_num = 2 ** layer
            for node_index in range(1, node_num + 1):
                # 先确定当前节点的数据集,标记
                node = Node()
                if layer == 0:
                    # 如果是第0层,则当前节点node的数据集就是训练集
                    node.data_set = self.x_train
                    node.labels = self.y_train
                else:
                    # 如果不是0层,则需要根据父节点来获取当前节点的数据集
                    # previous_index为生成当前节点的上一层父节点的索引
                    previous_index = int(np.ceil(node_index / 2)) - 1
                    previous_node = self.tree[layer - 1][previous_index]
                    if (previous_node is None) or (previous_node.attr_split is None):
                        # 判断如果父节点为叶子节点则跳过当前节点的构造
                        continue
                    attr_split = previous_node.attr_split
                    split_value = previous_node.split_value
                    data_set = previous_node.data_set
                    labels = previous_node.labels
                    if node_index % 2 != 0:
                        # 上一层节点的左分支(小于分割值)
                        data_index = data_set[:, attr_split] < split_value
                        node.data_set = data_set[data_index]
                        node.labels = labels[data_index]
                    else:
                        # 上一层节点的右分支(大于等于分割值)
                        data_index = data_set[:, attr_split] >= split_value
                        node.data_set = data_set[data_index]
                        node.labels = labels[data_index]
                # 确定当前节点数据集后,计算当前节点信息熵,确定是否进行分割,如果小于停止分割信息熵,则不分割,否则分割
                current_node_entropy = self.calc_entropy(node.labels)
                if current_node_entropy < self.stop_entropy:
                    # 小于停止分割信息熵则不对当前节点的attr_split以及split_value赋值,当前节点成为叶子节点
                    self.tree[layer][node_index - 1] = node
                    continue
                else:
                    # 确定当前节点的分割特征以及分割特征的值
                    courrent_entropy_gain = 0
                    current_attr_split = None
                    current_split_value = None
                    for attr_split_index in range(4):
                        # 遍历四个维度的特征
                        for split_value in self.split_values[attr_split_index]:
                            entropy_gain = self.calc_entropy_gain(attr_split_index, split_value, node.data_set, node.labels)
                            if entropy_gain >= courrent_entropy_gain:
                                courrent_entropy_gain = entropy_gain
                                current_attr_split = attr_split_index
                                current_split_value = split_value
                    node.attr_split = current_attr_split
                    node.split_value = current_split_value
                self.tree[layer][node_index - 1] = node
        values = list(self.tree.values())
        for value in values:
            if np.all(np.logical_not(np.array(value, np.bool))):
                index_ = values.index(value)
                break
        else:
            index_ = self.deepth
        nodes = self.tree[index_ - 1]
        for node in nodes:
            if isinstance(node, Node):
                node.attr_split = None
                node.split_value = None

    def calc_entropy(self, labels):
        """

        :param labels: 要计算熵值的数据集的标记
        :return: 数据集的熵值
        """
        # 计算信息熵
        label_unique = np.unique(labels)
        labels_list = list(labels)
        p_list = []
        for label in label_unique:
            p_list.append(labels_list.count(label) / len(labels_list))
        entropy = 0
        for p in p_list:
            entropy -= p * np.log2(p)
        return entropy

    def calc_condition_entropy(self, attr_split, split_value, data_set, labels):
        """

        :param attr_split: 用来分割数据集的特征
        :param split_value: 用来分割数据集的特征的值
        :param data_set: 被分割的数据集
        :param labels: 被分割的数据集标记
        :return: 条件熵, 大于分割特征特征值的数据集和标记以及小于分割特征特征值的数据集和标记
        """
        # 计算小于分割值的数据集的熵
        smaller_index = data_set[:, attr_split] < split_value
        smaller_data_set = data_set[smaller_index]
        smaller_label = labels[smaller_index]
        smaller_entropy = self.calc_entropy(smaller_label)
        # 计算大于分割值的数据集的熵
        bigger_index = data_set[:, attr_split] >= split_value
        bigger_data_set = data_set[bigger_index]
        bigger_label = labels[bigger_index]
        bigger_entropy = self.calc_entropy(bigger_label)
        # 两部分熵加权求和得到条件熵
        p_smaller = len(smaller_label) / len(labels)
        p_bigger = len(bigger_label) / len(labels)
        condition_entropy = p_smaller * smaller_entropy + p_bigger * bigger_entropy
        return condition_entropy, bigger_data_set, bigger_label, smaller_data_set, smaller_label

    def calc_entropy_gain(self, attr_split, split_value, data_set, labels):
        entropy = self.calc_entropy(labels)
        condition_entropy = self.calc_condition_entropy(attr_split, split_value, data_set, labels)[0]
        entropy_gain = entropy - condition_entropy
        return entropy_gain

    def leaf_node_label(self):
        # 确定叶子节点的类别
        nodes = []
        nodes_list = self.tree.values()
        for nd in nodes_list:
            nodes.extend(nd)
        while True:
            if None in nodes:
                nodes.remove(None)
            else:
                break
        for node in nodes:
            if not node.split_value:
                labels = np.unique(node.labels)
                p = []
                for label in labels:
                    p.append(list(node.labels).count(label) / len(node.labels))
                node.predict_label = labels[p.index(max(p))]

    def test(self):
        self.y_test_pred = []
        for test_sample in self.x_test:
            current_nodes = [self.tree[0][0]]
            for layer in range(1, self.deepth):
                current_attr_split = current_nodes[-1].attr_split
                current_split_value = current_nodes[-1].split_value
                # 寻找下一个子节点
                if test_sample[current_attr_split] < current_split_value:
                    current_node = self.tree[layer][self.tree[layer - 1].index(current_nodes[-1]) * 2]
                else:
                    current_node = self.tree[layer][self.tree[layer - 1].index(current_nodes[-1]) * 2 + 1]
                current_nodes.append(current_node)
                if not current_node.split_value:
                    # 如果已经到叶节点就不在继续
                    break
            y_pred = current_nodes[-1].predict_label
            self.y_test_pred.append(y_pred)
        print("鸢尾花数据测试集预测结果:")
        print("测试集预测类别:", self.y_test_pred)
        print("测试集真实类别:", list(self.y_test))
        print("预测准确率:%.2f%s" % (100 * accuracy_score(self.y_test, self.y_test_pred), "%"))


class Node(object):

    def __init__(self):
        # self.data_set为当前节点的样本
        self.data_set = None
        # self.labels为当前节点的样本对应的标记
        self.labels = None
        # self.attr_split为用来分割当前节点的特征是哪一维
        self.attr_split = None
        # self.split_value为分割当前节点特征的分割值是多少
        self.split_value = None
        # 如果当前节点为叶子节点,则self.predict_label代表当前叶节点的类别
        self.predict_label = None


def main():
    tree = DecesionTree(0.2, 3)
    tree.generate_tree()
    tree.leaf_node_label()
    tree.test()


if __name__ == "__main__":
    main()

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

决策树 的相关文章

  • TCP协议

    1 TCP协议基本概念 RFCs 793 1122 1323 2018 2581 1 1 TCP协议的特点 点对点 一个发送方 一个接收方 可靠的 按序的字节流 可靠的 多种确保可靠性的机制 字节流服务 8bit 1Byte 为最小单位构成
  • 好的软件架构设计

    什么是软件架构 前言 软体设计师中有一些技术水平较高 经验较为丰富的人 他们需要承担软件系统的架构设计 也就是需要设计系统的元件如何划分 元件之间如何发生相互作用 以及系统中逻辑的 物理的 系统的重要决定的作出 在很多公司中 架构师不是一个
  • 在VS中配置VTK

    VTK与VS结合使用有两种配置方式 一种是配置cmake 一种是配置VS 两者配置一者即可 我这里只讲解一下配置VS的步骤吧 我用的是VS2010 打开VS 找到属性 在VC 目录中修改包含目录和库目录为自己安装VTK的include和li

随机推荐

  • 华为OD机试真题-最多等和不相交连续子序列【2023Q2】【JAVA、Python、C++】

    题目描述 给定一个数组 我们称其中连续的元素为连续子序列 称这些元素的和为连续子序列的和 数组中可能存在几组连续子序列 组内的连续子序列互不相交且有相同的和 求一组连续子序列 组内子序列的数目最多 输出这个数目 输入描述 第一行输入为数组长
  • 怎么用计算机算lnx,ln计算(log计算器在线)

    ln MN lnM lnN ln M N lnM lnN ln M n nlnM ln1 0 lne 1 注意 拆开后 M N需要大于0 没有 ln M N lnM lnN 和ln M N lnM lnN lnx 是e x的反函数 也就是说
  • 华为设置android系统提醒功能,华为手机短信不提醒怎么办?华为手机短信提醒设置方法...

    华为手机短信提醒设置方法 1 检查当前设置的默认短信应用是哪个应用 点击桌面 设置 图标 找到 应用程序管理 选择 默认应用设置 选择 信息 可以看到当前正在使用的默认短信应用名称 如果使用的是第三方短信应用 请将 信息 勾选 改为使用默认
  • 制造蝴蝶飓风,微众区块链的蝶变和ESG新使命

    时间来到新世纪 共同繁荣 人与自然和谐发展等成为全球共识的背景下 越来越多的国家和组织开始践行ESG 环境 社会和公司治理 理念 在中国 乡村振兴 共同富裕 双碳战略 数字经济等国家级战略的推出 也旨在推动 效率优先 的发展模式 向公平与可
  • 生成openVPN客户端配置的shell脚本

    脚本介绍 在服务端 etc openvpn 目录下存放该脚本 client sh 脚本运行 运行方式 client sh client name client name填你想输入的客户端名称 例如 输入yes 和相应的ca密钥 运行成功如图
  • Python调用WebServer(WSDL)注意事项

    本人很少与WebServer交互 最近调用公司SAP的同步人员信息 发现一些很小的点 但是很浪费时间的注意事项 第一 不要去相信对方开发者嘴中所谓的JSON 很有可能是各种非标准JSON 这是非标准JSON 至于标准的 键带有双引号的 re
  • 关于域控DC不能正常同步GC的解决办法(域控时间超过墓碑时间) 与域控SRV记录

    现象 用户两台域控 GC PDC 上面创建用户DC不能正常同步 DC上面创建用户GC能够同步 同时发现有一台文件服务器有些机器不能正常访问 提示共享无权限 原因 用dcdiag命令在GC上没有问题 在DC上发现墓碑时间问题 可以确定是墓碑时
  • UnityHub打不开自己的项目的一个可能

    自己的unity项目前几天还一切正常 突然就打不开了 从unity跳转不到hub 从hub点项目转了几圈就没反应了 也没办法新建项目 看了网上很多解决方法 重新登录 没反应 删了unityhub重新下载 没反应 关闭防火墙重新插usb接口这
  • 关于Typora初次下载输入代码时代码行号不显示的问题

    关于Typora初次下载输入代码时代码行号不显示的问题 我刚用Typora的时候 打开代码块发现居然不显示行号 以下是我打开代码块内行号的显示的步骤 我刚用Typora的时候 打开代码块发现居然不显示行号 以下是我打开代码块内行号的显示的步
  • dataframe按照某一列的取值进行拆分

    dataframe按照某一列 假设列名为 columnname 的取值进行拆分 即 比如dataframe的第一列只有 a b 两种取值可能 就把dataframe拆分成两个小的dataframe 一个dataframe的第一列只取 a 另
  • 【WiFi】Hostapd工作流程分析

    目录 1 Hostapd概述 2 Hostapd代码框架 3 Hostapd各种命令配置工具 4 hostaod的主函数 5 hostaod代码分析 1 Hostapd概述 Hostapd是一个运行在用户态的守护进程 可以通过Hostapd
  • JavaScript 入门基础 - 流程控制(四)

    JavaScript 流程控制 分支和循环 文章目录 JavaScript 流程控制 分支和循环 1 什么是流程控制 2 顺序流程控制 3 分支流程控制 之 if语句 3 1 什么是分支结构 3 2 if 语句 3 2 1 if 语句基本理
  • IPV6组播地址

    1 IPV6组播地址 RFC4291定义组播地址格式如下 8 4 4 112 11111111 flgs scop group ID
  • nas文件服务器web接口,nas配置web服务器

    nas配置web服务器 内容精选 换一换 通过Web浏览器登录资源 会话页面载入失败 提示由于服务器长时间无响应 连接已断开 请检查您的网络并重试 Code T 514 云堡垒机系统与资源服务器之间网络连接不稳定 导致连接断开 云堡垒机系统
  • 计蒜客 蒜头君的新游戏(DP)

    蒜头君的新游戏 include
  • 构造函数设置为private,会怎样。

    构造函数设置为private 会怎样 1 无法静态的创建对象了 即不能通过 A a这种方式创建对象了 只能通过在类的内部的静态成员函数中new一个对象 动态的创建对象 include
  • NotScripts扩展在Chrome中禁用网页JavaScript

    经常上网查找资料的朋友 应该对于那些无法复制网页内容的网站是深有感触的 由于这些网站作者为保护自己的网站内容不被他人抄袭 使用了JavaScrip来禁用鼠标右键复制功能 解决办法当然就是用浏览器禁止使用网页的JS加载或者生效了 如果你经常使
  • Hive窗口函数大全

    Hive窗口函数 一 偏移量函数 lag lead 二 窗口分析函数 first value last value 三 排序函数 rank dense rank row number 一 偏移量函数 lag 语法 lag col n def
  • linux网络编程实现投票功能

    投票系统 1 说明 写了一个投票系统 过程是先配置好服务器 在写一个网上投票功能 要实现网上投票功能 其实功能实现还是很简单的 麻烦一点的在于过程比较繁杂 要做的东西还是挺多的 2 过程 第一步 配置httpd服务器 先配置好httpd服务
  • 决策树

    这篇博客用来简要介绍决策树算法 DecisionTree 决策树是机器学习中常用的一种算法 它即可用于解决分类问题 也可用于解决回归问题 在这篇博客我们只介绍分类决策树 决策树顾名思义是一种树形结构 而我们的任务就是想办法构建出这样一颗树用