最近邻回归(概念+实例)

发布于:2024-05-01 ⋅ 阅读:(24) ⋅ 点赞:(0)

目录

前言

一、基本概念

1. KNN回归的原理

2. KNN回归的工作原理举例

3. KNN回归的参数

4. KNN回归的优缺点

5. KNN回归的应用场景

二、实例


前言

最近邻回归(K-nearest neighbors regression,简称KNN回归)是一种简单而又直观的非参数回归方法。与其他回归方法不同,KNN回归不需要对数据进行假设,而是直接利用数据中的实例进行预测。在这种方法中,预测结果是由最接近输入实例的K个训练样本的输出值的加权平均得到的。KNN回归在数据量较小,且数据之间的关系较为复杂时表现出色,但在处理大规模数据时可能效率较低。

一、基本概念

1. KNN回归的原理

KNN回归的原理非常简单,它主要包括以下几个步骤:

  1. 计算距离: 对于给定的预测样本,首先计算它与训练集中每个样本之间的距离。通常采用的距离度量包括欧氏距离、曼哈顿距离、闵可夫斯基距离等。

  2. 找出最近邻: 从训练集中选取与预测样本距离最近的K个样本。

  3. 计算预测值: 对于回归问题,预测值通常是这K个最近邻样本输出值的加权平均。

2. KNN回归的工作原理举例

假设我们有一个包含多个特征(如房屋面积、房间数量等)和对应房价的训练数据集。现在有一个新的房屋,我们想要预测它的房价。这时我们可以使用KNN回归:

  1. 对于这个新的房屋,我们计算它与训练集中每个房屋之间的距离。

  2. 找出距离最近的K个房屋。

  3. 根据这K个房屋的房价,计算出新房屋的预测房价。

3. KNN回归的参数

KNN回归中的主要参数是K值和距离度量方法:

  • K值: 它决定了用于预测的最近邻的数量。较小的K值会使模型对噪声敏感,而较大的K值会使模型对数据局部特征的捕捉能力降低。

  • 距离度量方法: 常用的距离度量方法包括欧氏距离、曼哈顿距离和闵可夫斯基距离等。选择合适的距离度量方法对模型的性能具有重要影响。

4. KNN回归的优缺点

优点:

  • 简单直观: KNN回归的原理非常简单直观,易于理解和实现。
  • 无需假设: 与线性回归等参数方法不同,KNN回归不需要对数据的分布做任何假设。
  • 适用性广泛: KNN回归适用于各种类型的回归问题,并且对数据的分布形式没有要求。

缺点:

  • 计算复杂度高: 在预测时需要计算新样本与所有训练样本之间的距离,当训练集很大时,计算复杂度会很高。
  • 存储空间大: KNN回归需要存储整个训练集,当训练集很大时,会占用大量的存储空间。
  • 预测速度慢: 由于需要计算新样本与所有训练样本之间的距离,因此预测速度较慢,特别是在大规模数据集上。

5. KNN回归的应用场景

KNN回归在许多领域都有着广泛的应用,特别是在以下几个方面:

  • 房价预测: 如前文所述的房价预测是KNN回归的典型应用之一。
  • 销量预测: 根据历史销售数据,预测未来某个产品的销量。
  • 金融风险评估: 利用过去的金融数据,预测未来某个投资产品的收益率或风险。
  • 医学诊断: 根据患者的临床特征,预测患者是否患有某种疾病。

二、实例

这段代码创建了一组示例数据,其中包括房屋的面积和价格。然后,我们使用 KNN 回归模型拟合了这些数据,并使用模型预测了一系列房屋面积对应的价格。最后,我们绘制了训练数据和预测结果的图形。

代码:

# 导入所需的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor

# 创建示例数据
# 假设我们有5个房屋的数据,其中X是房屋的面积,y是房屋的价格
X_train = np.array([[80], [100], [120], [150], [200]])  # 房屋面积(单位:平方米)
y_train = np.array([300000, 350000, 400000, 450000, 500000])  # 房屋价格(单位:人民币)

# 定义 KNN 回归模型,这里我们选择 K=3
knn_regressor = KNeighborsRegressor(n_neighbors=3)

# 使用训练数据拟合模型
knn_regressor.fit(X_train, y_train)

# 生成一些测试数据(房屋面积)
X_test = np.arange(80, 201, 10).reshape(-1, 1)

# 使用模型进行预测
y_pred = knn_regressor.predict(X_test)

# 绘制结果
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, color='darkorange', label='Training data')
plt.plot(X_test, y_pred, color='navy', label='Prediction')
plt.xlabel('House Area (sqm)')
plt.ylabel('House Price (RMB)')
plt.title('KNN Regression: House Price Prediction')
plt.legend()
plt.grid(True)
plt.show()

结果:


网站公告

今日签到

点亮在社区的每一天
去签到