决策树技术详解:从理论到Python实战

发布于:2025-08-11 ⋅ 阅读:(13) ⋅ 点赞:(0)

决策树像人类的思考过程,用一系列“是/否”问题层层逼近答案

目录

一、决策树的核心本质

二、决策树的核心构成

三、决策树的数学原理

四、算法对比

五、决策树的双面性

六、Python代码实战

tree包有什么?

鸢尾花数据集(Iris Dataset)

1. DecisionTreeClassifier(决策树分类器)

基本用法

关键参数

1. 树的结构控制

2. 分裂策略

3. 正则化与防过拟合

4. 随机性与权重

gini、entropy与经典算法ID3、C4.5、CART的关系

1. 基尼系数(Gini) → CART 算法

2. 信息增益(Entropy) → ID3/C4.5 的启发

3. scikit-learn 的决策树实现本质是 CART

为什么 scikit-learn 选择 CART 框架?

如何选择 gini 或 entropy?

总结

2. DecisionTreeRegressor(决策树回归器)

基本用法

关键参数

3. ExtraTreeClassifier(极端随机树分类器)

基本用法

关键区别

4. ExtraTreeRegressor(极端随机树回归器)

基本用法

关键区别

5. 可视化决策树

使用 plot_tree(推荐)

使用 export_text(文本形式)

总结


一、决策树的核心本质

决策树是一种模仿人类决策过程的树形结构分类/回归模型。它通过节点(问题)​​ 和 ​边(答案)​ 构建路径,最终在叶节点(决策结果)​输出预测值。这种白盒模型的优势在于极高的可解释性


二、决策树的核心构成

  1. 根节点​:初始特征划分点
  2. 内部节点​:特征测试点(每个节点对应一个判断条件)
  3. 分支​:判断条件的可能结果
  4. 叶节点​:最终决策结果(分类/回归值)

关键概念​:

  • 纯度(Purity)​​:节点内样本类别的统一程度(Gini指数/熵越小越纯)
  • 信息增益(Information Gain)​​:分裂后纯度的提升量
  • 剪枝(Pruning)​​:防止过拟合的关键技术(预剪枝/后剪枝)

三、决策树的数学原理

决策树通过递归分割寻找最优特征:

1、​选择分裂特征​:

  • ID3算法​:使用信息增益​(缺陷:偏好多值特征)
  • C4.5算法​:改进为增益率​(消除特征取值数量的影响)
  • CART算法​:使用Gini指数​(计算效率更高) 

​2、停止条件​:

  • 节点样本全属同一类
  • 特征已用完
  • 样本数低于阈值(超参数控制)

四、算法对比


参考:

算法 原生设计 能否用于回归 原因/限制
ID3 分类(信息增益) ❌ 不能直接使用 依赖离散标签的熵计算,回归任务是连续值,无法直接计算类别纯度。
C4.5 分类(增益率) ❌ 不能直接使用 同ID3,分裂标准基于分类熵,且要求离散特征。
CART 分类+回归 ✅ 可直接使用 设计时同时支持Gini指数(分类)和最小方差(回归),天然兼容连续值目标变量。
维度 分类树(ID3/C4.5/CART)​ 回归树(CART)​
分裂标准 信息增益、增益率、Gini指数 最小化方差(MSE)或绝对误差(MAE)
叶节点输出 类别标签 连续值(均值/中位数)
特征类型 离散特征(ID3/C4.5)或混合(CART) 支持连续和离散特征

ID3和C4.5原生仅支持分类,但通过替换分裂标准和离散化连续特征,可间接适配回归任务。实际应用中,​CART是更高效且通用的选择


五、决策树的双面性

​优势 ✅:

  • 直观可视化(业务人员可理解)
  • 无需数据标准化
  • 支持混合特征(数值+类别)

局限 ⚠️​:

  • 对数据扰动敏感(小变动可能导致结构剧变)
  • 容易过拟合(必须剪枝
  • 不适合学习复杂关系(如异或问题

延展思考​:决策树作为集成学习的基模型(如随机森林/XGBoost)时,通过“群体智慧”能极大克服自身缺陷。在实际应用中,超过80%的预测场景会优先尝试树模型家族。

六、Python代码实战

tree包有什么?

from sklearn.tree import DecisionTreeClassifier, plot_tree

看一下tree包有些什么?

"""Decision tree based models for classification and regression."""
# 模块文档字符串,说明这个模块提供基于决策树的分类和回归模型

# Authors: The scikit-learn developers
# 标明作者是 scikit-learn 开发团队

# SPDX-License-Identifier: BSD-3-Clause
# 软件许可证声明,使用 BSD 3-Clause 许可证

from ._classes import (
    BaseDecisionTree,  # 决策树的基类(抽象类)
    DecisionTreeClassifier,  # 决策树分类器
    DecisionTreeRegressor,  # 决策树回归器
    ExtraTreeClassifier,  # 极端随机树分类器
    ExtraTreeRegressor,  # 极端随机树回归器
)
# 从当前目录的 _classes.py 模块导入决策树相关类

from ._export import export_graphviz, export_text, plot_tree
# 从当前目录的 _export.py 模块导入可视化导出函数:
# - export_graphviz: 导出Graphviz格式的可视化
# - export_text: 导出文本形式的决策规则
# - plot_tree: 绘制决策树图形

__all__ = [
    "BaseDecisionTree",  # 公开的基类
    "DecisionTreeClassifier",  # 公开的分类器
    "DecisionTreeRegressor",  # 公开的回归器
    "ExtraTreeClassifier",  # 公开的极端随机树分类器
    "ExtraTreeRegressor",  # 公开的极端随机树回归器
    "export_graphviz",  # 公开的Graphviz导出函数
    "plot_tree",  # 公开的绘图函数
    "export_text",  # 公开的文本导出函数
]
# 定义模块的公开API,当使用 from sklearn.tree import * 时,只有这里列出的名称会被导入

关键点说明:

  1. 模块结构分为两部分:

    • _classes.py 包含决策树的核心实现类

    • _export.py 包含可视化相关工具函数

  2. 提供的5个核心类:

    • 基类 BaseDecisionTree 包含通用实现(BaseDecisionTree 是 scikit-learn 中所有决策树模型的基类(抽象基类),它定义了决策树的核心框架和通用方法,但普通用户通常不会直接使用它。它的主要用途是作为 DecisionTreeClassifierDecisionTreeRegressor 等具体实现类的父类,提供共享的逻辑和接口。除非你要自定义一种新的决策树变体(例如实现一种新的分裂准则),否则不需要直接使用 BaseDecisionTree。)

    • 标准决策树和极端随机树(ExtraTrees)两种变体

    • 每种变体都有分类和回归版本

  3. 可视化工具:

    • 支持图形化(plot_tree)和文本(export_text)两种展示方式

    • 支持导出到Graphviz格式(export_graphviz)用于进一步处理

  4. __all__ 严格控制了模块的公开接口

鸢尾花数据集(Iris Dataset)

鸢尾花数据集是机器学习领域最经典的数据集之一,由英国统计学家和生物学家 ​Ronald Fisher​ 在1936年提出,常用于分类算法的入门和测试。

  • 样本数量​:150 个样本(3 类 × 50 个样本)
  • 特征数量​:4 个数值型特征
  • 目标类别​:3 种鸢尾花品种
特征(Feature) 描述
花萼长度(sepal length) 单位:cm
花萼宽度(sepal width) 单位:cm
花瓣长度(petal length) 单位:cm
花瓣宽度(petal width) 单位:cm
类别(Target) 描述
-------------- ------
Setosa(山鸢尾) 线性可分,容易分类
Versicolor(杂色鸢尾) 与 Virginica 部分重叠
Virginica(维吉尼亚鸢尾) 与 Versicolor 部分重叠
from sklearn.datasets import load_iris
import pandas as pd

iris = load_iris()
X = iris.data  # 特征矩阵 (150, 4)
y = iris.target  # 类别标签 (0, 1, 2)
feature_names = iris.feature_names  # 特征名称
target_names = iris.target_names  # 类别名称

# 转为DataFrame(可选)
df = pd.DataFrame(X, columns=feature_names)
df['species'] = [target_names[label] for label in y]
print(df.head())


在 sklearn.tree 模块中,提供了几种不同的决策树模型,适用于分类和回归任务。


1. DecisionTreeClassifier(决策树分类器)

适用于分类问题(预测离散类别标签)。

基本用法

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器
clf = DecisionTreeClassifier(
    max_depth=3,       # 树的最大深度(防止过拟合)
    criterion="gini",  # 分裂标准:"gini"(基尼系数)或 "entropy"(信息增益)
    random_state=42,   # 随机种子(确保结果可复现)
)

# 训练模型
clf.fit(X_train, y_train)

# 预测
y_pred = clf.predict(X_test)

# 评估准确率
accuracy = clf.score(X_test, y_test)
print(f"测试集准确率: {accuracy:.2f}") #测试集准确率: 1.00

关键参数

  • max_depth:树的最大深度(控制过拟合)

  • criterion:分裂标准("gini" 或 "entropy",默认"gini" )

  • min_samples_split:节点分裂所需的最小样本数

  • min_samples_leaf:叶子节点所需的最小样本数

  • random_state:随机种子(确保结果可复现)

所有参数如下

 def __init__(
        self,
        *,
        criterion="gini",
        splitter="best",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        random_state=None,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        class_weight=None,
        ccp_alpha=0.0,
        monotonic_cst=None,
    )
1. 树的结构控制
参数 默认值 作用
max_depth None 树的最大深度。None 表示不限制,直到所有叶子节点纯净或达到 min_samples_split防止过拟合的关键参数
max_leaf_nodes None 最大叶子节点数。优先调整 max_depth,此参数作为补充限制。
min_samples_split 2 节点分裂所需的最小样本数。若样本数 < 此值,则不再分裂。
min_samples_leaf 1 叶子节点所需的最小样本数。分裂后子节点样本数必须 ≥ 此值。
min_weight_fraction_leaf 0.0 叶子节点样本权重和的最小占比(加权数据时使用)。

2. 分裂策略
参数 默认值 作用
criterion "gini"

分裂质量的衡量标准:
- 分类:"gini"(基尼系数)、 "entropy"(信息增益)
- 回归:"squared_error"(MSE)"friedman_mse" 、 "absolute_error"(MAE)

splitter "best" 分裂策略:
"best":选择最优分裂
"random":随机选择分裂(更快的训练,适合 ExtraTree
max_features None 寻找最优分裂时考虑的最大特征数:
None:全部特征
"sqrt":√(总特征数)
"log2":log₂(总特征数)
- 整数/浮点数:直接指定数量/比例

3. 正则化与防过拟合
参数 默认值 作用
min_impurity_decrease 0.0 分裂的最小不纯度减少量。若分裂后不纯度减少 < 此值,则停止分裂。
ccp_alpha 0.0 代价复杂度剪枝的 α 参数(≥0)。值越大,剪枝越激进。
monotonic_cst None 单调性约束(高级功能),强制预测值随特征单调变化。

4. 随机性与权重
参数 默认值 作用
random_state None 随机种子,控制特征/分裂的随机选择(确保结果可复现)。
class_weight None 类别权重:
None:所有类别权重=1
"balanced":自动按类别频率反比加权
- 字典:手动指定类别权重(如 {0: 0.5, 1: 1.0}

gini、entropy与经典算法ID3、C4.5、CART的关系

在 scikit-learn 的决策树实现中,criterion="gini"(基尼系数)和 criterion="entropy"(信息增益)的选择与经典算法(ID3、C4.5、CART)的关系如下:


1. 基尼系数(Gini) → CART 算法
  • 对应算法:CART(Classification and Regression Trees,分类与回归树)

  • 特点

    • 基尼系数是 CART 算法默认的分裂标准(用于分类任务)。

    • 计算更高效(无需对数运算),但结果通常与信息增益非常接近。


2. 信息增益(Entropy) → ID3/C4.5 的启发
  • 对应算法:ID3 和 C4.5 使用信息增益(或增益比),但 scikit-learn 并未完全实现 C4.5

  • 特点

    • 信息增益基于信息熵,计算稍慢(涉及对数运算)。

    • 注意:scikit-learn 的 entropy 仅实现信息增益,未实现 C4.5 的增益比(Gain Ratio),因此不完全等同于 C4.5。


3. scikit-learn 的决策树实现本质是 CART

无论选择 gini 还是 entropy,scikit-learn 的底层实现均基于 CART 框架,与经典算法的主要区别如下:

特性 CART (scikit-learn) ID3 C4.5
分裂标准 基尼系数或信息增益 信息增益 增益比(Gain Ratio)
任务类型 分类 + 回归 仅分类 仅分类
特征类型 支持连续和离散特征 仅离散特征 支持连续和离散
二叉树/多叉树 二叉树(总是二元分裂) 多叉树 多叉树
缺失值处理 内置支持 不支持 支持
剪枝方式 代价复杂度剪枝(CCP) 悲观剪枝

为什么 scikit-learn 选择 CART 框架?
  1. 统一性:CART 同时支持分类和回归任务(ID3/C4.5 仅支持分类)。

  2. 效率:二叉树结构比多叉树更高效,适合大规模数据。

  3. 灵活性:支持连续特征和缺失值处理(无需像 ID3 那样预处理离散化)。


如何选择 gini 或 entropy
  • 基尼系数(Gini):计算更快(推荐默认使用)。对类别分布不均匀的数据更鲁棒。

  • 信息增益(Entropy):理论更贴近信息论。可能生成更平衡的树(但对性能影响通常很小)。

实际应用中,两者的分类效果通常差异不大,优先选择 gini(除非有特定需求)。


总结
  • criterion="gini" → CART 算法的标准实现。

  • criterion="entropy" → 借鉴了 ID3/C4.5 的思想,但仍在 CART 框架下运行。

  • scikit-learn 没有完整实现 ID3/C4.5(如多叉树、增益比等功能),其决策树本质是 CART 的优化版本。


2. DecisionTreeRegressor(决策树回归器)

适用于回归问题(预测连续值)。

基本用法

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error

# 加载数据
housing = fetch_california_housing()
X, y = housing.data, housing.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树回归器
reg = DecisionTreeRegressor(
    max_depth=4,          # 树的最大深度
    criterion="squared_error",  # 分裂标准(MSE)
    random_state=42,
)

# 训练模型
reg.fit(X_train, y_train)

# 预测
y_pred = reg.predict(X_test)

# 评估(均方误差 MSE)
mse = mean_squared_error(y_test, y_pred)
print(f"测试集均方误差: {mse:.2f}")

关键参数

  • criterion:分裂标准("squared_error"(MSE)、"friedman_mse" 或 "absolute_error"(MAE))

  • 其他参数与 DecisionTreeClassifier 类似(max_depthmin_samples_split 等)


3. ExtraTreeClassifier(极端随机树分类器)

与 DecisionTreeClassifier 类似,但分裂时随机选择特征和阈值(更随机化,训练更快,可能泛化更好)。

基本用法

from sklearn.tree import ExtraTreeClassifier

# 创建极端随机树分类器
clf = ExtraTreeClassifier(
    max_depth=3,
    criterion="gini",
    random_state=42,
)

# 训练和预测(与 DecisionTreeClassifier 相同)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"测试集准确率: {clf.score(X_test, y_test):.2f}")

关键区别

  • 分裂时随机选择特征和阈值(比普通决策树更随机)

  • 训练速度更快,但可能牺牲一些准确率


4. ExtraTreeRegressor(极端随机树回归器)

与 DecisionTreeRegressor 类似,但分裂时随机选择特征和阈值

基本用法

from sklearn.tree import ExtraTreeRegressor

# 创建极端随机树回归器
reg = ExtraTreeRegressor(
    max_depth=4,
    criterion="squared_error",
    random_state=42,
)

# 训练和预测(与 DecisionTreeRegressor 相同)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"测试集均方误差: {mse:.2f}")

关键区别

  • 分裂时随机选择特征和阈值

  • 训练更快,适用于大数据集


5. 可视化决策树

可以使用 plot_tree 或 export_graphviz 可视化决策树。

使用 plot_tree(推荐)

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 8))
plot_tree(
    clf,  # 训练好的决策树模型
    feature_names=iris.feature_names,  # 特征名
    class_names=iris.target_names,     # 类别名
    filled=True,                       # 填充颜色
    rounded=True,                      # 圆角节点
)
plt.show()

输出示例


使用 export_text(文本形式)

from sklearn.tree import export_text

tree_rules = export_text(
    clf,
    feature_names=iris.feature_names,
)
print(tree_rules)

输出示例

|--- median_income <= 5.03
|   |--- ocean_proximity_INLAND <= 0.50
|   |   |--- median_income <= 3.11
|   |   |   |--- median_income <= 2.21
|   |   |   |   |--- truncated branch of depth 21
|   |   |   |--- median_income >  2.21
|   |   |   |   |--- truncated branch of depth 20
|   |   |--- median_income >  3.11
|   |   |   |--- longitude <= -118.31
|   |   |   |   |--- truncated branch of depth 24
|   |   |   |--- longitude >  -118.31
|   |   |   |   |--- truncated branch of depth 22
|   |--- ocean_proximity_INLAND >  0.50
|   |   |--- median_income <= 3.04
|   |   |   |--- median_income <= 2.22
|   |   |   |   |--- truncated branch of depth 17
|   |   |   |--- median_income >  2.22
|   |   |   |   |--- truncated branch of depth 19
|   |   |--- median_income >  3.04
|   |   |   |--- median_income <= 4.07
|   |   |   |   |--- truncated branch of depth 16
|   |   |   |--- median_income >  4.07
|   |   |   |   |--- truncated branch of depth 14
|--- median_income >  5.03
|   |--- median_income <= 6.87
|   |   |--- ocean_proximity_INLAND <= 0.50
|   |   |   |--- housing_median_age <= 36.50
|   |   |   |   |--- truncated branch of depth 18
|   |   |   |--- housing_median_age >  36.50
|   |   |   |   |--- truncated branch of depth 11
|   |   |--- ocean_proximity_INLAND >  0.50
|   |   |   |--- housing_median_age <= 32.50
|   |   |   |   |--- truncated branch of depth 11
|   |   |   |--- housing_median_age >  32.50
|   |   |   |   |--- truncated branch of depth 5
|   |--- median_income >  6.87
|   |   |--- median_income <= 8.16
|   |   |   |--- housing_median_age <= 27.50
|   |   |   |   |--- truncated branch of depth 12
|   |   |   |--- housing_median_age >  27.50
|   |   |   |   |--- truncated branch of depth 8
|   |   |--- median_income >  8.16
|   |   |   |--- total_bedrooms <= 33.00
|   |   |   |   |--- truncated branch of depth 2
|   |   |   |--- total_bedrooms >  33.00
|   |   |   |   |--- truncated branch of depth 9

总结

模型 适用任务 关键特点 示例场景
DecisionTreeClassifier 分类 标准决策树 鸢尾花分类
DecisionTreeRegressor 回归 标准决策树 房价预测
ExtraTreeClassifier 分类 更随机的分裂 高维数据分类
ExtraTreeRegressor 回归 更随机的分裂 大数据回归
  • 默认用 DecisionTreeClassifier/Regressor(更稳定)

  • 数据量大时用 ExtraTree(训练更快)

  • 防止过拟合:调整 max_depthmin_samples_split 等参数


网站公告

今日签到

点亮在社区的每一天
去签到