KNN 算法详解
KNN(K-Nearest Neighbors,K 近邻算法)是一种基本的分类和回归算法,其主要特点是基于距离测量进行决策。KNN 算法的核心思想是:给定一个训练数据集,对新的输入实例,找到与该实例最近的 K 个训练实例,依据这些实例的类别决定输入实例的类别。
算法原理
计算距离
对于输入样本,计算它与训练集中的每个样本之间的距离。选择最近的 K 个样本
按照距离从小到大排序,选出距离最近的 K 个样本。投票或平均
- 分类任务:根据 K 个邻居的类别,采用投票法(多数原则)决定输入样本的类别。
- 回归任务:计算 K 个邻居样本值的平均值,作为输入样本的预测值。
距离度量方式
KNN 常用的距离度量包括:
欧几里得距离
d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2} d(x,y)=i=1∑n(xi−yi)2曼哈顿距离
d ( x , y ) = ∑ i = 1 n ∣ x i − y i ∣ d(x, y) = \sum_{i=1}^{n} |x_i - y_i| d(x,y)=i=1∑n∣xi−yi∣余弦相似度(适用于文本数据等高维数据)
similarity ( x , y ) = x ⋅ y ∥ x ∥ ∥ y ∥ \text{similarity}(x, y) = \frac{x \cdot y}{\|x\| \|y\|} similarity(x,y)=∥x∥∥y∥x⋅y闵可夫斯基距离(通用形式)
d ( x , y ) = ( ∑ i = 1 n ∣ x i − y i ∣ p ) 1 / p d(x, y) = \left( \sum_{i=1}^{n} |x_i - y_i|^p \right)^{1/p} d(x,y)=(i=1∑n∣xi−yi∣p)1/p
当 (p=2) 时为欧几里得距离,当 (p=1) 时为曼哈顿距离。
关键参数
K 的选择
- (K) 值过小:容易受噪声影响,导致过拟合。
- (K) 值过大:模型可能过于平滑,导致欠拟合。
- 一般通过交叉验证或经验选择合适的 (K) 值。
距离度量
根据数据类型选择适合的距离度量方法。数据归一化
- 如果特征的量纲差异较大,需要进行归一化(如 Min-Max 标准化或 Z-Score 标准化),避免某些特征对距离的影响过大。
优缺点
优点
- 简单直观:实现容易,无需训练过程。
- 无参数模型:适合对非线性分布数据进行分类或回归。
- 适用性广:对多分类问题或回归问题均有效。
缺点
- 计算量大:需要存储所有数据,预测时需要计算与所有样本的距离。
- 对噪声敏感:容易受噪声样本干扰。
- 维数灾难:随着数据维度增加,距离度量的效果会降低。
实现步骤
准备数据
- 收集训练数据集和测试数据集。
数据预处理
- 数据标准化或归一化(如将数值范围归一到 [0,1])。
选择超参数
- 确定 (K) 值和距离度量方法。
预测过程
- 对测试样本计算与所有训练样本的距离。
- 按距离从小到大排序,选择 K 个最近样本。
- 分类任务中进行投票;回归任务中取平均。
评估模型
- 使用准确率(分类)或均方误差(回归)来评估模型性能。
Python 实现示例
import numpy as np
from collections import Counter
def knn_predict(X_train, y_train, X_test, k=3):
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
y_pred = []
for test_point in X_test:
# 计算距离
distances = [euclidean_distance(test_point, x) for x in X_train]
# 获取最近的 k 个点
k_indices = np.argsort(distances)[:k]
k_nearest_labels = [y_train[i] for i in k_indices]
# 投票
most_common = Counter(k_nearest_labels).most_common(1)
y_pred.append(most_common[0][0])
return np.array(y_pred)
示例数据
X_train = np.array([[1, 2], [2, 3], [3, 4], [5, 6]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([[1.5, 2.5], [4, 5]])
预测结果
predictions = knn_predict(X_train, y_train, X_test, k=3)
print(predictions)
常见应用
1. 文本分类
• 使用 TF-IDF 或词向量表示文本后,基于余弦相似度进行分类。
2. 图像识别
• 对手写数字分类(如 MNIST 数据集)。
3. 推荐系统
• 基于用户或物品的相似度进行推荐。
总结
KNN 是一种简单但有效的分类和回归方法,适合在数据规模较小且特征较少时使用。在实际应用中,为提高效率,可以结合 KD-Tree 或 Ball-Tree 进行加速。