K-Means 聚类算法通俗解释
什么是 K-Means 聚类?
K-Means 是一种 "自动分类" 的算法,就像你整理衣柜时会把衣服分成 "上衣"、"裤子"、"裙子" 几类一样,它能帮你把数据按照相似性分成 K 个组(K 是你设定的数字)。
打个比方理解
想象你是一位老师,手里有 30 个学生的成绩单(包含数学、语文、英语分数),你想把他们分成 3 个学习水平相近的小组:
- 先随便选 3 个学生作为 "临时组长"(初始化质心)
- 让其他学生根据自己的分数和组长的相似度,选择加入哪个组(分配簇)
- 每个组重新选一个新组长(成绩最接近组内平均的学生)(更新质心)
- 重复第 2-3 步,直到组长不再变化(迭代优化)
我们的案例做了什么?
我们用了一个包含 "收入"、"年龄"、"学历" 的数据表,想把人群分成几类:
准备工作:
- 把 "学历" 文字(小学、初中等)变成数字(1-5)方便计算
- 标准化数据:比如收入数值很大(几千),年龄较小(几十),需要转换成同一尺度
聚类过程:
- 我们选择分成 2 类(k=2)
- 用 k-means++ 方法选择初始 "中心"(比随机选更合理)
- 计算每个数据点到两个中心的距离(用的是欧氏距离,类似坐标系中两点距离)
- 把每个点分到最近的中心所在的组
- 重新计算每个组的新中心(组内所有点的平均值)
- 重复上面两步,直到中心位置不再变化
结果展示:
- 用三维图展示分类结果(x 轴收入,y 轴年龄,z 轴学历)
- 不同颜色代表不同的簇(组)
- 红色 "X" 代表每个簇的中心
代码演示
一步一步学操作
准备工具
就像做饭需要锅碗瓢盆,我们需要这些 Python 库:# 导入需要的库 import numpy as np # 用于数值计算 import pandas as pd # 用于处理表格数据 from sklearn.cluster import KMeans # 用于K-Means聚类 import matplotlib.pyplot as plt # 用于画图 from sklearn.preprocessing import StandardScaler # 用于数据标准化 from mpl_toolkits.mplot3d import Axes3D # 用于画3D图
导入数据
我们的数据长这样:# 导入数据 df = pd.read_excel('https://labfile.oss.aliyuncs.com/courses/40611/%E8%81%9A%E7%B1%BB%E5%88%86%E6%9E%90.xlsx') # 查看前5行数据 df.head()
运行后会看到:
收入 年龄 学历 6668 57 研究生 14658 60 研究生 6493 17 初中 12458 26 小学 12179 24 小学 数据处理
学历是文字(小学、初中等),我们把它转成数字:
# 把学历文字转换成数字 df['学历'] = df['学历'].map({ "小学": 1, "初中": 2, "高中": 3, "大学": 4, "研究生": 5 })
标准化:因为收入是几千上万,年龄是几十,数字大小差别大,需要统一尺度:
# 数据标准化 scaler = StandardScaler() df_scaled = scaler.fit_transform(df)
K-Means 核心步骤
选 K 值并初始化:这里我们选 2,意思是想分成 2 类
# 设置聚类数为2 k = 2 # 初始化K-Means模型,使用k-means++方法选择初始中心点 kmeans = KMeans(n_clusters=k, init='k-means++')
执行聚类:
# 对标准化后的数据进行聚类 kmeans.fit(df_scaled) # 获取每个数据点的聚类标签(0或1) labels = kmeans.labels_ print(labels) # 会输出类似 [0, 0, 1, ..., 0, 0, 1] 的结果
查看聚类中心:
# 获取每个簇的中心点 centroids = kmeans.cluster_centers_ print(centroids) # 输出两个中心点的坐标
可视化结果
我们可以画出三维图直观展示聚类结果:# 创建3D图形 fig = plt.figure(figsize=(10, 8)) ax = fig.add_subplot(111, projection='3d') # 绘制数据点,不同颜色表示不同簇 scatter = ax.scatter(df_scaled[:, 0], df_scaled[:, 1], df_scaled[:, 2], c=labels, cmap='viridis') # 绘制质心(用红色X标记) ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], s=300, c='red', marker='X') # 添加标题和坐标轴标签 ax.set_xlabel('收入') ax.set_ylabel('年龄') ax.set_zlabel('学历') # 调整视角 ax.view_init(elev=30, azim=70) # 添加图例 legend1 = ax.legend(*scatter.legend_elements(), title="类别") ax.add_artist(legend1) # 显示图形 plt.show()
你能从结果中看到什么?
- 图中明显分成了两群点,说明数据中确实存在两种不同特征的人群
- 红色中心代表了每类人群的 "典型特征"(平均收入、平均年龄、平均学历)
- 同一簇内的人,在收入、年龄、学历上更相似
为什么要用这个算法?
- 不需要提前知道分类标准(无监督学习)
- 速度快,适合处理大量数据
- 结果直观,容易理解
简单总结步骤
- 确定要分成几类(K 值)
- 选初始中心
- 按距离分组
- 重新算中心
- 重复 3-4 直到稳定
- 看结果,分析每组的特点