DAY 37

发布于:2025-05-29 ⋅ 阅读:(20) ⋅ 点赞:(0)

对信贷数据集训练后保存权重,加载权重后继续训练50轮,并采取早停策略

import lightgbm as lgb
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
 
# Build the LightGBM datasets.
# free_raw_data=False keeps the raw feature matrices attached to the Dataset
# objects after the first lgb.train() call constructs them. Continued training
# with init_model has to run the loaded model over the raw data to compute
# initial scores; with the default (free_raw_data=True) the later lgb.train()
# call fails with "Cannot set predictor after freed raw data".
# NOTE(review): the validation set is built from X_test/y_test, the same split
# used for the reported metrics, so early stopping on it makes those metrics
# optimistic -- consider a separate validation split.
train_data = lgb.Dataset(X_train, label=y_train, free_raw_data=False)
valid_data = lgb.Dataset(
    X_test,
    label=y_test,
    reference=train_data,
    free_raw_data=False,
)

# LightGBM hyper-parameters
params = {
    'objective': 'binary',  # binary classification task
    'metric': 'auc',       # evaluate with AUC
    'boosting_type': 'gbdt',  # gradient-boosted decision trees
    'learning_rate': 0.01,    # shrinkage applied to each boosting round
    'num_leaves': 31,         # maximum number of leaves per tree
    'random_state': 42        # random seed for reproducibility
}
 
# Fit the initial model for 100 boosting rounds, logging the validation
# metric every 10 rounds.
initial_model = lgb.train(
    params=params,
    train_set=train_data,
    num_boost_round=100,
    valid_sets=[valid_data],
    callbacks=[lgb.log_evaluation(period=10)],
)

# Persist the trained booster so it can be reloaded later.
initial_model.save_model('initial_lgb_model.txt')
 
# Score the initial model on the test set; predicted probabilities at or
# above 0.5 are labelled as class 1.
y_pred_proba = initial_model.predict(X_test)
y_pred = (y_pred_proba >= 0.5).astype(int)

# Compute the confusion matrix once and reuse it for the report and the plot.
cm = confusion_matrix(y_test, y_pred)

print("初始LightGBM模型性能:")
print("混淆矩阵:\n", cm)
print("分类报告:\n", classification_report(y_test, y_pred))
print("测试集AUC:", roc_auc_score(y_test, y_pred_proba))

# Draw the confusion matrix as an annotated heatmap.
plt.figure(figsize=(8, 6))
ax = sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=['预测为0', '预测为1'],
    yticklabels=['实际为0', '实际为1'],
)
ax.set_title("初始LightGBM混淆矩阵")
ax.set_xlabel("预测标签")
ax.set_ylabel("真实标签")
plt.show()
 
# Reload the booster that was saved after the initial training run.
loaded_model = lgb.Booster(model_file='initial_lgb_model.txt')

# Continue boosting from the loaded model for up to 50 more rounds.
# Training stops early when the validation metric has not improved for
# 10 consecutive rounds; the metric is logged every 10 rounds.
# NOTE(review): continuing from init_model requires the Dataset objects to
# still hold their raw data (free_raw_data=False at construction) -- confirm
# how train_data / valid_data were built.
continued_model = lgb.train(
    params=params,
    train_set=train_data,
    num_boost_round=50,
    init_model=loaded_model,
    valid_sets=[valid_data],
    callbacks=[
        lgb.early_stopping(stopping_rounds=10, verbose=True),
        lgb.log_evaluation(period=10),
    ],
)

# Persist the final booster.
continued_model.save_model('final_lgb_model.txt')
 
# Score the continued model on the test set; predicted probabilities at or
# above 0.5 are labelled as class 1.
y_pred_proba_final = continued_model.predict(X_test)
y_pred_final = (y_pred_proba_final >= 0.5).astype(int)

# Compute the confusion matrix once and reuse it for the report and the plot.
cm_final = confusion_matrix(y_test, y_pred_final)

print("最终LightGBM模型性能(继续训练50轮后):")
print("混淆矩阵:\n", cm_final)
print("分类报告:\n", classification_report(y_test, y_pred_final))
print("测试集AUC:", roc_auc_score(y_test, y_pred_proba_final))

# Draw the final confusion matrix as an annotated heatmap.
plt.figure(figsize=(8, 6))
ax = sns.heatmap(
    cm_final,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=['预测为0', '预测为1'],
    yticklabels=['实际为0', '实际为1'],
)
ax.set_title("最终LightGBM混淆矩阵")
ax.set_xlabel("预测标签")
ax.set_ylabel("真实标签")
plt.show()

@浙大疏锦行