1 引言
上一篇文章发表日期是7月27日。考虑到我在8月休了半个月婚假,能在有限的空闲时间里抽空学习并沉淀自己的认知内容,已让我感到十分满足。
当然,能在9月中旬完成这篇文章,除了个人的“努力”之外,内容本身不算复杂也是一个重要原因。本篇文章的核心是“因果森林”。对于具备一定机器学习基础的读者而言,“因果森林”这一名称很容易让人联想到“随机森林”。事实上,两者之间确实存在诸多相似之处。将它们进行对比,有助于更深入地理解“因果森林”的算法原理。
正文如下。
2 因果树
正如学习随机森林之前需先掌握决策树,介绍因果森林前也需先理解因果树。
2.1 算法原理
因果树的概念提出于2016年,主要用于评估异质性因果效应(heterogeneous causal effects)。简单来说,如果某一处理(treatment)对所有样本的效果一致,则为同质性因果效应,例如接种疫苗后所有人感染概率均下降90%;而如果处理在不同子群体或特征下效果不同,则为异质性因果效应,例如广告在不同地区的转化率提升幅度不同。
一个简单的因果树结构如下:原始样本集 S S S的平均因果效应为 τ S \tau_S τS。通过某种分割规则,形成左叶节点 S L S_L SL和右叶节点 S R S_R SR,对应的因果效应分别为 τ L \tau_L τL和 τ R \tau_R τR。
每个分组的CATE计算方式为:
τ = Y 1 − Y 0 \tau = Y_1-Y_0 τ=Y1−Y0
其中, Y 1 Y_1 Y1和 Y 0 Y_0 Y0分别表示分组内处理组和对照组的平均结果值。
S L S_L SL和 S R S_R SR代表不同子群体, τ L \tau_L τL和 τ R \tau_R τR的差异越大,说明树的划分越有效。划分优劣可通过以下指标衡量:
Q = n L ⋅ ( τ L − τ S ) 2 + n R ⋅ ( τ R − τ S ) 2 Q = n_L ·(\tau_L-\tau_S)^2 + n_R ·(\tau_R-\tau_S)^2 Q=nL⋅(τL−τS)2+nR⋅(τR−τS)2
Q Q Q值越大,说明分割带来的处理效应差异越显著,分割更“有意义”。
算法原理如上,较为简明。下表对比了决策树(用于分类)与因果树:
决策树 | 因果树 | |
---|---|---|
目标 | 分类 | 估计处理效应 |
分割标准 | 最小化不纯度gini | 最大化 Q Q Q |
叶节点内容 | 分类结果 | CATE |
适用场景 | 分类预测 | 异质性效应估计 |
在数据集划分方面,决策树和因果树也有所不同:
- 决策树:通常划分为训练集和测试集,训练集用于模型训练,测试集用于评估泛化能力。
- 因果树:通常划分为分割集和估计集,前者用于确定分割规则,后者用于计算CATE,这种方法称为“Honest approaches”。
2.2 实例计算
为加深对因果树算法的理解,下面通过具体实例手动演示划分过程。
假设有一份医疗实验数据集,包含以下信息:
- 特征:年龄(Age)、性别(Gender)
- 处理:是否服用新药(Treatment, 1=服用,0=未服用)
- 结果:血压下降值(Outcome, 单位mmHg)
我们希望利用因果树发现不同人群对新药的降压效果。数据共20条,明细如下:
ID | Age | Gender | Treatment | Outcome |
---|---|---|---|---|
1 | 22 | 0 | 1 | 13 |
2 | 33 | 1 | 0 | 7 |
3 | 41 | 0 | 1 | 10 |
4 | 27 | 1 | 0 | 6 |
5 | 38 | 0 | 1 | 12 |
6 | 45 | 1 | 1 | 9 |
7 | 36 | 0 | 0 | 8 |
8 | 24 | 1 | 0 | 7 |
9 | 52 | 0 | 1 | 8 |
10 | 29 | 1 | 0 | 6 |
11 | 31 | 0 | 1 | 11 |
12 | 40 | 1 | 0 | 8 |
13 | 23 | 0 | 1 | 14 |
14 | 47 | 1 | 0 | 5 |
15 | 34 | 0 | 0 | 6 |
16 | 51 | 1 | 1 | 7 |
17 | 26 | 0 | 0 | 7 |
18 | 39 | 1 | 1 | 10 |
19 | 44 | 0 | 0 | 5 |
20 | 28 | 1 | 1 | 12 |
2.2.1 划分原始数据集
首先,将原始数据随机分为分割集(前10条)和估计集(后10条)。
分割集中,处理组的ID为1,3,5,6,9,平均结果变量为10.4;对照组的ID为2,4,7,8,10,平均结果变量为6.8。因此,分割集总体CATE为:
τ S = 10.4 − 6.8 = 3.6 \tau_S=10.4 - 6.8 = 3.6 τS=10.4−6.8=3.6
2.2.2 确定最佳分割
假设允许两种分割方式:按Age=30划分,或按性别划分。
按Age=30划分:
左叶节点(Age≤30):共4个样本。处理组的ID为1,结果变量为13;对照组的ID为4,8,10,平均结果变量为6.33。CATE为:
τ L = 13 − 6.33 = 6.67 \tau_L=13 - 6.33 = 6.67 τL=13−6.33=6.67右叶节点(Age>30):共6个样本。处理组的ID为3,5,6,9,平均结果变量为9.75;对照组的ID为2和7,平均结果变量为7.5。CATE为:
τ R = 9.75 − 7.5 = 2.25 \tau_R=9.75 - 7.5 = 2.25 τR=9.75−7.5=2.25
对应的 Q Q Q值为:
Q a g e = 4 ⋅ ( 6.67 − 3.6 ) 2 + 6 ⋅ ( 2.25 − 3.6 ) 2 = 48.6 Q_{age}=4·(6.67-3.6)^2 + 6·(2.25-3.6)^2=48.6 Qage=4⋅(6.67−3.6)2+6⋅(2.25−3.6)2=48.6
如下图所示,分割集在Age=30条件下的分割结果:
**按性别分割:**同理计算,得到 Q g e n d e r = 9.7 Q_{gender}=9.7 Qgender=9.7。
由于 Q g e n d e r < Q a g e Q_{gender} < Q_{age} Qgender<Qage,因此Age=30是更优的分割方案。
2.2.3 评估因果效应
切换至估计集,按照Age=30标准分组:
- 左叶节点:共3个样本。处理组的ID为13和20,平均结果变量为13;对照组的ID为17,结果变量为7。CATE为13-7=6。
- 右叶节点:共7个样本。处理组的ID为11,16,18,平均结果变量为9.33;对照组的ID为12,14,15,19,平均结果变量为6。CATE为9.33 - 6 = 3.33。
最终评估结果如下:
分组 | CATE |
---|---|
Age≤30 | 6 |
Age>30 | 3.33 |
估计集在Age=30条件下的分割结果如下图所示:
3 因果森林
因果森林可以理解为由多棵因果树组成的集成模型,其算法原理与随机森林高度相似,具体内容可参考:随机森林原理和性能分析。
多棵树优于单棵树的理论基础是孔多塞陪审团定理。简而言之,多数投票的正确概率高于任何单一模型;当模型数量足够大时,集成模型的准确率将趋近于完美。
基于因果树构建因果森林,常见的两种方式如下:
- Bootstrap采样(Bagging):每棵因果树都在从原始数据集有放回抽样得到的子集(bootstrap sample)上训练。这种方式可以降低单棵树对特定数据点的依赖,提高集成模型的稳健性。
- 特征随机选择(Random Feature Selection):每次分裂节点时,并非使用全部特征,而是从中随机选择一部分作为候选分割变量。这可以避免某些强特征主导所有树的分割,提高模型对高维数据的适应能力,并减少过拟合。
训练完因果森林后,对每个样本的CATE估算方式是:在所有树中分别预测其CATE,然后对这些预测值取平均。
4 代码实例
下面展示如何通过代码使用因果森林模型估计因果效应。
整体代码框架与因果推断 | 元学习方法原理详解和代码实操一致,分为三个主要步骤:
- 构造实例数据:需要注意的是,树模型适用于评估异质性因果效应,因此将synthetic_data中的mode参数由1(线性结果)改为2(非线性结果)。
- 训练算法模型:使用causalml工具包,分别调用CausalTreeRegressor和CausalRandomForestRegressor以实现因果树和因果森林模型。为便于对比,代码中还保留了X-learner和DML模型。
- 评估模型效果:评估指标包括ATE、IDE和AUUC,并对IDE和AUUC进行了可视化展示。
import numpy as np
import pandas as pd
from xgboost import XGBRegressor
from causalml.inference.meta import BaseXRegressor
from causalml.dataset import synthetic_data
from econml.dml import CausalForestDML
from causalml.metrics import auuc_score, plot_gain
import matplotlib.pyplot as plt
from causalml.inference.tree import CausalRandomForestRegressor, CausalTreeRegressor
def plot_sorted_tau_and_preds(df):
# tau真实值和4个模型预测
tau = df['tau_true'].values
preds = {
'X-learner': df['x-learner'].values,
'DML': df['DML'].values,
'CausalForest': df['CausalForest'].values,
'CausalTree': df['CausalTree'].values
}
# 排序索引
idx = np.argsort(tau)
tau_sorted = tau[idx]
preds_sorted = {k: v[idx] for k, v in preds.items()}
x = np.arange(len(tau))
# 统一y轴范围
y_all = [tau_sorted] + [preds_sorted[k] for k in preds]
ymin = min([arr.min() for arr in y_all])
ymax = max([arr.max() for arr in y_all])
# 画4个子图
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
model_names = list(preds.keys())
for i, ax in enumerate(axes.flatten()):
ax.scatter(x, tau_sorted, label='True tau', color='black', s=10, alpha=0.7)
ax.scatter(x, preds_sorted[model_names[i]], label=model_names[i], color='tab:blue', s=10, alpha=0.7)
ax.set_title(f'{model_names[i]}')
ax.set_xlabel('Sample (sorted by tau)')
ax.set_ylabel('ITE')
ax.set_ylim(ymin, ymax)
ax.legend()
plt.tight_layout()
plt.show()
def calc_by_package(X, treatment, y, tau):
# X-learner
learner_x = BaseXRegressor(learner=XGBRegressor())
ate_x = learner_x.fit_predict(X=X, treatment=treatment, y=y)
print('estimated causal effect, by X-learner: {:.04f}'.format(np.mean(ate_x)))
# DML
cf_dml = CausalForestDML(model_t=XGBRegressor(), model_y=XGBRegressor())
cf_dml.fit(y, treatment, X=X)
ate_dml = cf_dml.effect(X)
print('estimated causal effect, by DML: {:.04f}'.format(np.mean(ate_dml)))
# 拟合因果森林
cf_cf = CausalRandomForestRegressor()
cf_cf.fit(X, treatment, y)
ate_cf = cf_cf.predict(X)
print('estimated causal effect, by RandomForest: {:.04f}'.format(np.mean(ate_cf)))
# 拟合因果树
cf_cr = CausalTreeRegressor()
cf_cr.fit(X, treatment, y)
ate_cr = cf_cr.predict(X)
print('estimated causal effect, by CausalTree: {:.04f}'.format(np.mean(ate_cr)))
# 获取叶节点数量和对应的CATE
print("叶节点数量:", cf_cr.tree_.n_leaves)
leaf_ids = cf_cr.apply(X)
unique_leaves = np.unique(leaf_ids)
leaf_cate_dict = {}
for leaf in unique_leaves:
idx = (leaf_ids == leaf)
treat_idx = idx & (treatment == 1)
control_idx = idx & (treatment == 0)
if treat_idx.sum() > 0 and control_idx.sum() > 0:
cate = y[treat_idx].mean() - y[control_idx].mean()
else:
cate = np.nan # 或者跳过该节点
leaf_cate_dict[leaf] = cate
for leaf, cate in leaf_cate_dict.items():
print(f"叶节点 {leaf} 的 CATE: {cate}")
# 合并结果
df = pd.DataFrame({
'y': y,
'treat': treatment,
'x-learner': np.ravel(ate_x),
'DML': np.ravel(ate_dml),
'CausalForest': np.ravel(ate_cf),
'CausalTree': np.ravel(ate_cr)
})
df['tau_true'] = tau
print('true causal effect: {}'.format(np.mean(tau)))
# auuc
auuc = auuc_score(df, outcome_col='y', treatment_col='treat', normalize=True, tmle=False)
print(auuc)
return df
if __name__ == '__main__':
np.random.seed(0)
# y-观测结果;X-样本特征;treatment-处理变量;tau-个体处理效应
y, X, treatment, tau, b, e = synthetic_data(mode=2, p=25)
result_df = calc_by_package(X, treatment, y, tau)
plot_sorted_tau_and_preds(result_df)
# 绘制auuc曲线
plot_gain(
result_df,
outcome_col='y',
treatment_col='treat',
normalize=True,
random_seed=10,
n=100,
figsize=(8, 8)
)
plt.show()
结果解读:
- ATE结果:结果较为意外,因果树的预测更接近真值,但这一结论并不具备普适性。例如,将数据中的p(协变量数量)由25增加到40时,因果森林的ATE指标会更接近真值。
estimated causal effect, by X-learner: 0.7352
estimated causal effect, by DML: 0.5525
estimated causal effect, by RandomForest: 0.6582
estimated causal effect, by CausalTree: 0.6999
true causal effect: 0.7946483837099231
- ITE结果:因果树仅保留7个叶节点,因此ITE仅有7个离散值(0位置附近,有两个值);而因果森林通过多棵树平均效应,使ITE分布更加均衡。
叶节点数量: 7
叶节点 4 的 CATE: -1.2794264565485223
叶节点 5 的 CATE: 0.04997737614570308
叶节点 7 的 CATE: 0.6736109791416229
叶节点 8 的 CATE: 0.12112326235879434
叶节点 10 的 CATE: 1.1193301595067435
叶节点 11 的 CATE: 2.0795258688365204
叶节点 12 的 CATE: 1.5968641890556292
- AUUC结果:因果森林在排序能力上优于因果树。AUUC关注模型在ITE较高个体(图右上部分)上排序的准确性。从ITE分布图可见,因果森林认为ITE较大的样本,其真实值普遍分布在右侧;而因果树预测的最大值则横跨600以上区间。
x-learner 1.281181
DML 1.059788
CausalForest 1.026942
CausalTree 0.941297
tau_true 0.971806
5 总结
正文到此结束,核心内容总结如下:
- 因果森林模型能够有效评估异质性因果效应,适用于因果推断任务中个体化或分群处理效应的估计。
- 因果森林的基础构件是因果树,其核心思想是通过寻找最优分割点,使不同叶节点的CATE(条件平均处理效应)差异最大化,从而揭示处理效应的异质性。
6 相关阅读
Recursive partitioning for heterogeneous causal effects:https://www.pnas.org/doi/10.1073/pnas.1510489113
Estimation and Inference of Heterogeneous Treatment Effects using Random Forests:https://www.tandfonline.com/doi/full/10.1080/01621459.2017.1319839
决策树入门、sklearn实现、原理解读和算法分析:https://mp.weixin.qq.com/s/PbtFMBylahNSKteiBClEZw
随机森林原理和性能分析:https://mp.weixin.qq.com/s/E9izVenKjmp4jCnpFw51rA
因果推断 | 元学习方法原理详解和代码实操:https://mp.weixin.qq.com/s/zA5PU0uXw-ZMOKJPBNyBJg