目录
✅ 项目背景
我们使用一个简化的玉米产量数据集(可模拟实际数据),训练一个机器学习模型预测玉米产量,并使用 SHAP 值解释模型的关键影响因素。
📦 所用工具
Python
pandas、scikit-learn
xgboost
shap
📁 数据字段(模拟)
字段 | 含义 |
---|---|
rainfall | 降雨量(mm) |
temperature | 温度(℃) |
soil_nitrogen | 土壤氮含量 |
fertilizer | 施肥量 |
yield | 玉米产量(目标变量) |
🧑💻 代码实现步骤
# 第一步:导入库
import pandas as pd
import numpy as np
import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# 第二步:构造或加载数据
data = pd.DataFrame({
'rainfall': np.random.uniform(100, 300, 200),
'temperature': np.random.uniform(15, 30, 200),
'soil_nitrogen': np.random.uniform(0.5, 2.0, 200),
'fertilizer': np.random.uniform(50, 150, 200),
})
# 模拟目标变量
data['yield'] = (
0.05 * data['rainfall'] +
0.1 * data['temperature'] +
0.2 * data['soil_nitrogen'] +
0.03 * data['fertilizer'] +
np.random.normal(0, 2, 200)
)
# 第三步:划分训练集与测试集
X = data.drop('yield', axis=1)
y = data['yield']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 第四步:训练模型
model = xgb.XGBRegressor()
model.fit(X_train, y_train)
# 第五步:SHAP 值解释
explainer = shap.Explainer(model)
shap_values = explainer(X_test)
# 第六步:可视化解释
shap.plots.beeswarm(shap_values)
🎯 解读与启发
使用 SHAP 分析后,发现“soil_nitrogen”与“rainfall”对模型预测影响最大,说明氮含量和降雨量是玉米产量的关键变量。
利用这类可解释性分析,有助于科学家与农业管理者构建可信的AI模型,避免“黑箱模型”带来的误解与风险。
🧠 项目拓展建议
加入 LIME 对比分析;
更换模型为随机森林、LightGBM 等;
用真实遥感+气象数据集训练,提高实用性。