Python Day11

发布于:2025-07-14 ⋅ 阅读:(21) ⋅ 点赞:(0)

@浙大疏锦行 Python Day11

内容:

  • 参数:手动设置(超参数),模型学习(内参数)
  • 模型 = 算法 + 参数
  • 寻找参数框架:网格搜索(爆搜),随机搜索,贝叶斯搜索(优化随机搜索)
  • `time`库计时
import time

start_time = time.time()
# proc...
end_time = time.time()

print(f"time is {end_time - start_time}")

 代码:

import pandas as pd  # 用于数据处理和分析,可处理表格数据。
import numpy as np  # 用于数值计算,提供了高效的数组操作。
import matplotlib.pyplot as plt  # 用于绘制各种类型的图表
import seaborn as sns  # 基于matplotlib的高级绘图库,能绘制更美观的统计图形。
from sklearn.svm import SVC #支持向量机分类器
from sklearn.neighbors import KNeighborsClassifier #K近邻分类器
from sklearn.linear_model import LogisticRegression #逻辑回归分类器
import xgboost as xgb #XGBoost分类器
import lightgbm as lgb #LightGBM分类器
from sklearn.ensemble import RandomForestClassifier #随机森林分类器
from catboost import CatBoostClassifier #CatBoost分类器
from sklearn.tree import DecisionTreeClassifier #决策树分类器
from sklearn.naive_bayes import GaussianNB #高斯朴素贝叶斯分类器
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score # 用于评估分类器性能的指标
from sklearn.metrics import classification_report, confusion_matrix #用于生成分类报告和混淆矩阵
import warnings #用于忽略警告信息
warnings.filterwarnings("ignore") # 忽略所有警告信息
from sklearn.model_selection import train_test_split

# 设置中文字体(解决中文显示问题)
plt.rcParams['font.sans-serif'] = ['SimHei']  # Windows系统常用黑体字体
plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号


data = pd.read_csv("./data/heart.csv")

# 这里不需要处理离散值以及缺失值
X = data.drop(['target'], axis=1)
y = data['target']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("----------默认参数的随机森林-------------")
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(f"预测精度为:{accuracy_score(y_pred, y_test)}")
print("---------默认参数的随机森林结束------------")
print("----------网格搜索的随机森林-------------")
from sklearn.model_selection import GridSearchCV
import time
# 定义要搜索的参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20, 30],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}
grid_search = GridSearchCV(estimator=RandomForestClassifier(random_state=42), # 随机森林分类器
                           param_grid=param_grid, # 参数网格
                           cv=5, # 5折交叉验证
                           n_jobs=-1, # 使用所有可用的CPU核心进行并行计算
                           scoring='accuracy') # 使用准确率作为评分标准
start_time = time.time()
grid_search.fit(X_train, y_train)
end_time = time.time()
print(f"网格搜索耗时: {end_time - start_time:.4f} 秒")
print("最佳参数: ", grid_search.best_params_) #best_params_属性返回最佳参数组合
# 使用最佳参数的模型进行预测
best_model = grid_search.best_estimator_ # 获取最佳模型
best_pred = best_model.predict(X_test) # 在测试集上进行预测
print("\n网格搜索优化后的随机森林 在测试集上的分类报告:")
print(classification_report(y_test, best_pred))
print("网格搜索优化后的随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, best_pred))
print("-------------网格搜索随机森林结束-----------------")
print("-------------贝叶斯参数随机森林-------------------")
from skopt import BayesSearchCV
from skopt.space import Integer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import time

# 定义要搜索的参数空间
search_space = {
    'n_estimators': Integer(50, 200),
    'max_depth': Integer(10, 30),
    'min_samples_split': Integer(2, 10),
    'min_samples_leaf': Integer(1, 4)
}

# 创建贝叶斯优化搜索对象
bayes_search = BayesSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    search_spaces=search_space,
    n_iter=32,  # 迭代次数,可根据需要调整
    cv=5, # 5折交叉验证,这个参数是必须的,不能设置为1,否则就是在训练集上做预测了
    n_jobs=-1,
    scoring='accuracy'
)

start_time = time.time()
# 在训练集上进行贝叶斯优化搜索
bayes_search.fit(X_train, y_train)
end_time = time.time()

print(f"贝叶斯优化耗时: {end_time - start_time:.4f} 秒")
print("最佳参数: ", bayes_search.best_params_)

# 使用最佳参数的模型进行预测
best_model = bayes_search.best_estimator_
best_pred = best_model.predict(X_test)

print("\n贝叶斯优化后的随机森林 在测试集上的分类报告:")
print(classification_report(y_test, best_pred))
print("贝叶斯优化后的随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, best_pred))


print("-------------贝叶斯结束---------------")


网站公告

今日签到

点亮在社区的每一天
去签到