【机器学习】非参数贝叶斯回归方法 GPR

发布于:2025-06-26 ⋅ 阅读:(22) ⋅ 点赞:(0)

Gaussian Process Regression (GPR) 是一种强大的 非参数贝叶斯回归方法,适用于拟合非线性关系,并能提供预测的不确定性。

与传统的线性回归模型不同,GPR 能够通过指定的核函数捕捉复杂的非线性关系,并提供不确定性的估计。

下面将详细介绍 GPR 的原理、实现步骤,并附上完整的 Python 实现示例,包括数据生成过程。

🎯 一、GPR 方法原理

1.1 核心思想

GPR 假设目标函数是一个 高斯过程(Gaussian Process) 的样本。高斯过程是定义在输入空间上的 随机函数集合,任意有限个点上的函数值服从联合高斯分布:
在这里插入图片描述

1.2 回归模型设定

假设我们有训练数据:
在这里插入图片描述

1.3 后验推断

构建联合高斯分布:
在这里插入图片描述

1.4 常见核函数

核函数是 GPR 的核心,它决定了模型的平滑度、周期性等特性。选择合适的核函数可以显著提高模型的性能。常见的核函数包括:

  • RBF (Gaussian) Kernel:适用于平滑且连续的函数建模。
    在这里插入图片描述
  • Matern Kernel
  • Polynomial Kernel
  • Dot Product Kernel
  • 线性核:适用于线性关系建模。

核函数的形式和参数需要根据具体问题进行选择和调整。

🛠️ 二、Python 实现 GPR(含数据生成)

2.1 使用 scikit-learn 实现 GPR(基础案例)

核函数设计:RBF 核 + 白噪声 WhiteKernel 是常见组合。
参数优化:n_restarts_optimizer 用于多次尝试最优化超参数。
不确定性估计:GPR 不仅给出预测值,还提供置信区间(方差)。

在这里插入图片描述

完整Python代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel as C
import matplotlib as mpl

# 设置字体
mpl.rcParams['font.family'] = 'Times New Roman'


# 1. 生成训练数据
np.random.seed(42)
X_train = np.linspace(0, 5, 20).reshape(-1, 1)
y_train = np.sin(X_train).ravel() + np.random.normal(0, 0.1, X_train.shape[0])

# 2. 构造核函数
kernel = C(1.0) * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.1)

# 3. 初始化并训练 GPR 模型
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10)
gpr.fit(X_train, y_train)

# 4. 测试点预测
X_test = np.linspace(0, 6, 100).reshape(-1, 1)
y_pred, y_std = gpr.predict(X_test, return_std=True)

# 5. 可视化结果
plt.figure(figsize=(10, 6))
plt.plot(X_train, y_train, 'ro', label="Training Data")
plt.plot(X_test, y_pred, 'b-', label="Mean Prediction")
plt.fill_between(X_test.ravel(),
                 y_pred - 1.96 * y_std,
                 y_pred + 1.96 * y_std,
                 alpha=0.3, color='blue', label="95% Confidence Interval")

plt.title("Gaussian Process Regression", fontsize=16,fontweight='bold')
plt.xlabel("x", fontsize=16,fontweight='bold')
plt.ylabel("f(x)", fontsize=16,fontweight='bold')

plt.legend()
plt.grid(True)
plt.show()

参考


网站公告

今日签到

点亮在社区的每一天
去签到