# 在信贷数据集上训练模型并保存权重;随后加载权重继续训练50轮,并采用早停策略
import lightgbm as lgb
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
# Build the LightGBM datasets.
# free_raw_data=False keeps the raw feature matrices attached to the Dataset
# objects after they have been constructed. These same Dataset objects are
# reused further down for continued training with `init_model`; LightGBM then
# has to re-read the raw data to compute init scores from the loaded model.
# With the default (free_raw_data=True) that second `lgb.train` call fails
# with "Cannot set predictor after freed raw data". The trade-off is that the
# raw data stays in memory for the lifetime of the Dataset objects.
train_data = lgb.Dataset(X_train, label=y_train, free_raw_data=False)
valid_data = lgb.Dataset(
    X_test,
    label=y_test,
    reference=train_data,
    free_raw_data=False,
)
# NOTE(review): the validation set is built from X_test/y_test, which is also
# the data the models are scored on later in this script. Early stopping on
# it makes the reported test metrics optimistic -- consider a separate
# validation split.

# LightGBM training parameters.
params = {
    'objective': 'binary',     # binary classification task
    'metric': 'auc',           # evaluate with AUC
    'boosting_type': 'gbdt',   # gradient-boosted decision trees
    'learning_rate': 0.01,     # learning rate
    'num_leaves': 31,          # maximum number of leaves per tree
    'random_state': 42,        # random seed
}
# Fit the baseline booster for a fixed 100 boosting rounds (no early stopping
# at this stage); the validation metric is printed every 10 rounds.
initial_model = lgb.train(
    params=params,
    train_set=train_data,
    num_boost_round=100,
    valid_sets=[valid_data],
    callbacks=[lgb.log_evaluation(period=10)],
)
# Persist the baseline booster so training can be resumed from this file.
initial_model.save_model('initial_lgb_model.txt')
# Score the baseline model on X_test and turn the predicted scores into hard
# 0/1 labels using a 0.5 cut-off.
y_pred_proba = initial_model.predict(X_test)
y_pred = (y_pred_proba >= 0.5).astype(int)
# The confusion matrix is computed once and shared by the text report and the
# heatmap below.
cm = confusion_matrix(y_test, y_pred)

print("初始LightGBM模型性能:")
print("混淆矩阵:\n", cm)
print("分类报告:\n", classification_report(y_test, y_pred))
print("测试集AUC:", roc_auc_score(y_test, y_pred_proba))

# Render the confusion matrix as an annotated heatmap.
plt.figure(figsize=(8, 6))
initial_ax = sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=['预测为0', '预测为1'],
    yticklabels=['实际为0', '实际为1'],
)
initial_ax.set_title("初始LightGBM混淆矩阵")
initial_ax.set_xlabel("预测标签")
initial_ax.set_ylabel("真实标签")
plt.show()
# Reload the baseline booster from the model file written after the initial
# training run.
loaded_model = lgb.Booster(model_file='initial_lgb_model.txt')

# Resume boosting from the reloaded model for up to 50 additional rounds.
continued_model = lgb.train(
    params=params,
    train_set=train_data,
    num_boost_round=50,
    init_model=loaded_model,
    valid_sets=[valid_data],
    callbacks=[
        # Stop once the validation metric has not improved for 10 rounds.
        lgb.early_stopping(stopping_rounds=10, verbose=True),
        # Print the validation metric every 10 rounds.
        lgb.log_evaluation(period=10),
    ],
)

# Persist the final booster.
continued_model.save_model('final_lgb_model.txt')
# Score the continued model on X_test and turn the predicted scores into hard
# 0/1 labels using a 0.5 cut-off.
y_pred_proba_final = continued_model.predict(X_test)
y_pred_final = (y_pred_proba_final >= 0.5).astype(int)
# The confusion matrix is computed once and shared by the text report and the
# heatmap below.
cm_final = confusion_matrix(y_test, y_pred_final)

print("最终LightGBM模型性能(继续训练50轮后):")
print("混淆矩阵:\n", cm_final)
print("分类报告:\n", classification_report(y_test, y_pred_final))
print("测试集AUC:", roc_auc_score(y_test, y_pred_proba_final))

# Render the final confusion matrix as an annotated heatmap.
plt.figure(figsize=(8, 6))
final_ax = sns.heatmap(
    cm_final,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=['预测为0', '预测为1'],
    yticklabels=['实际为0', '实际为1'],
)
final_ax.set_title("最终LightGBM混淆矩阵")
final_ax.set_xlabel("预测标签")
final_ax.set_ylabel("真实标签")
plt.show()