Python中的决策树机器学习模型简要介绍和代码示例(基于sklearn)

发布于:2025-08-04 ⋅ 阅读:(11) ⋅ 点赞:(0)

一、决策树定义

决策树是一种监督学习算法,可用于**分类(Classification)回归(Regression)**任务。

它的结构类似树状结构:

  • 内部节点:特征条件(如X > 2
  • 叶子节点:输出类别或数值
  • 路径:对应一系列条件组合

二、决策树的基本概念

1. 信息熵(Entropy)

衡量样本集合的不确定性:

H(D)=−∑k=1Kpklog⁡2pk H(D) = - \sum_{k=1}^{K} p_k \log_2 p_k H(D)=k=1Kpklog2pk

其中:

  • DDD:样本集合
  • pkp_kpk:类别 kkk 的概率

2. 信息增益(Information Gain)

衡量某特征对信息熵的降低程度:

IG(D,A)=H(D)−∑v=1V∣Dv∣∣D∣H(Dv) IG(D, A) = H(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} H(D^v) IG(D,A)=H(D)v=1VDDvH(Dv)

  • DvD^vDv:按特征 AAA 值划分的子集
  • 常用于 ID3 算法

3. 信息增益率(Gain Ratio)

用于 C4.5 算法,避免信息增益偏好取值多的特征:

GainRatio(D,A)=IG(D,A)IV(A) \text{GainRatio}(D, A) = \frac{IG(D, A)}{IV(A)} GainRatio(D,A)=IV(A)IG(D,A)

  • IV(A)=−∑v=1V∣Dv∣∣D∣log⁡2∣Dv∣∣D∣IV(A) = -\sum_{v=1}^V \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}IV(A)=v=1VDDvlog2DDv

4. Gini系数(Gini Impurity)

CART 分类算法使用的分裂标准:

Gini(D)=1−∑k=1Kpk2 Gini(D) = 1 - \sum_{k=1}^{K} p_k^2 Gini(D)=1k=1Kpk2

越小表示纯度越高。


5. 均方误差(MSE)

用于决策树回归:

MSE=1N∑i=1N(yi−y^)2 MSE = \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y})^2 MSE=N1i=1N(yiy^)2

其中 y^\hat{y}y^ 是某叶子节点上的预测值。


三、决策树的算法流程(以分类为例)

  1. 选择最优划分特征

    • 使用信息增益 / 信息增益率 / Gini 系数
  2. 划分数据集

    • 递归构建子树
  3. 停止条件

    • 数据已纯净 / 特征用尽 / 达到最大深度
  4. 剪枝(可选)

    • 预剪枝 / 后剪枝,防止过拟合

四、实际示例

我们用一个简单的例子说明:

天气 温度 湿度 打球
正常
正常
正常

目标是预测“是否打球”。


五、代码实现

示例使用 Python + scikit-learn 实现

(1)决策树分类 + 可视化

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# 加载数据
X, y = load_iris(return_X_y=True)

# 创建模型
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf.fit(X, y)

# 预测
y_pred = clf.predict(X[:5])
print("预测结果:", y_pred)

# 可视化
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=load_iris().feature_names, class_names=load_iris().target_names)
plt.show()

(2)决策树回归 + 可视化

from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 模拟数据
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 模型训练
reg = DecisionTreeRegressor(max_depth=4)
reg.fit(X, y)

# 可视化预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_pred = reg.predict(X_test)

plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_pred, color="cornflowerblue", label="prediction")
plt.legend()
plt.title("Decision Tree Regression")
plt.show()

六、超参数 & 控制项

参数 说明
criterion 划分标准(gini, entropy, squared_error
max_depth 最大深度
min_samples_split 内部节点最小样本数
min_samples_leaf 叶子节点最小样本数
max_features 用于划分的特征数

七、剪枝技巧

决策树容易过拟合训练数据,特别是当树结构过深时。剪枝用于控制模型复杂度,提高泛化能力。


1、剪枝类型概览

类型 说明
预剪枝(Pre-Pruning) 在构建过程中限制深度、叶子样本数等
后剪枝(Post-Pruning) 构建完决策树后,自底向上裁剪不必要的分支

本内容聚焦在 后剪枝实现


2、后剪枝算法思想

基本流程:

  1. 从叶子节点向上回溯每个子树
  2. 判断当前子树(有分支) vs 将其剪成叶子节点谁的准确率更高(或损失更小)
  3. 若剪枝后效果更好 ⇒ 剪枝(用子树上样本的多数类替代)

3、后剪枝代码实现(基于自己实现的决策树)

下面是一个简洁的 ID3 决策树分类 + 后剪枝的完整实现:

1. 数据准备
from collections import Counter

# 简化示例数据
data = [
    ['晴', '高', '弱', '否'],
    ['晴', '高', '强', '否'],
    ['阴', '高', '弱', '是'],
    ['雨', '中', '弱', '是'],
    ['雨', '低', '弱', '是'],
    ['雨', '低', '强', '否'],
    ['阴', '低', '强', '是']
]

# 特征名称
features = ['天气', '温度', '风']

2. 构建决策树(ID3)
def entropy(data):
    labels = [row[-1] for row in data]
    counter = Counter(labels)
    total = len(data)
    return -sum((count/total) * (count/total).bit_length() for count in counter.values())

def split_dataset(data, axis, value):
    return [row[:axis] + row[axis+1:] for row in data if row[axis] == value]

def choose_best_feature(data):
    base_entropy = entropy(data)
    best_gain = 0
    best_feature = -1
    num_features = len(data[0]) - 1

    for i in range(num_features):
        values = set(row[i] for row in data)
        new_entropy = 0
        for val in values:
            subset = split_dataset(data, i, val)
            prob = len(subset) / len(data)
            new_entropy += prob * entropy(subset)
        gain = base_entropy - new_entropy
        if gain > best_gain:
            best_gain = gain
            best_feature = i
    return best_feature

def majority_class(data):
    labels = [row[-1] for row in data]
    return Counter(labels).most_common(1)[0][0]

def build_tree(data, features):
    labels = [row[-1] for row in data]
    if labels.count(labels[0]) == len(labels):
        return labels[0]
    if len(features) == 0:
        return majority_class(data)

    best_feat = choose_best_feature(data)
    best_feat_name = features[best_feat]
    tree = {best_feat_name: {}}

    feat_values = set(row[best_feat] for row in data)
    sub_features = features[:best_feat] + features[best_feat+1:]

    for val in feat_values:
        subset = split_dataset(data, best_feat, val)
        tree[best_feat_name][val] = build_tree(subset, sub_features)
    return tree

3. 后剪枝实现
def classify(tree, features, sample):
    if not isinstance(tree, dict):
        return tree
    root = next(iter(tree))
    sub_tree = tree[root]
    idx = features.index(root)
    value = sample[idx]
    subtree = sub_tree.get(value)
    if not subtree:
        return None
    return classify(subtree, features, sample)

def accuracy(tree, features, data):
    correct = 0
    for row in data:
        if classify(tree, features, row[:-1]) == row[-1]:
            correct += 1
    return correct / len(data)

def prune_tree(tree, features, data):
    if not isinstance(tree, dict):
        return tree

    root = next(iter(tree))
    idx = features.index(root)
    new_tree = {root: {}}

    for val, subtree in tree[root].items():
        subset = [row for row in data if row[idx] == val]
        if not subset:
            new_tree[root][val] = subtree
        else:
            pruned_subtree = prune_tree(subtree, features[:idx] + features[idx+1:], split_dataset(data, idx, val))
            new_tree[root][val] = pruned_subtree

    # 尝试剪枝为单叶节点
    flat_labels = [row[-1] for row in data]
    majority = majority_class(data)

    # 原树精度
    original_acc = accuracy(new_tree, features, data)
    # 剪枝后精度(所有预测为多数类)
    pruned_acc = flat_labels.count(majority) / len(flat_labels)

    if pruned_acc >= original_acc:
        return majority
    else:
        return new_tree

4. 使用示例
# 构建树
tree = build_tree(data, features)
print("原始决策树:", tree)

# 后剪枝
pruned = prune_tree(tree, features, data)
print("剪枝后树:", pruned)

4、剪枝效果展示(示意)

原始决策树:
{'天气': {
    '雨': {'风': {'弱': '是', '强': '否'}},
    '阴': '是',
    '晴': {'风': {'弱': '否', '强': '否'}}
}}

剪枝后树:
{'天气': {
    '雨': '是',
    '阴': '是',
    '晴': '否'
}}

八、优缺点总结

优点:

  • 易理解,树结构直观
  • 可处理分类与回归
  • 可解释性强

缺点:

  • 容易过拟合
  • 对小变化敏感
  • 对连续变量划分不如 ensemble 方法鲁棒


网站公告

今日签到

点亮在社区的每一天
去签到