目录
一、混淆矩阵到底是干嘛的?
混淆矩阵就是把预测结果 vs 真实标签交叉对比,然后用一个矩阵(或热力图)显示每种预测结果的数量,让你能精确看到模型在哪些类别上预测对了,哪里搞错了。
二、场景示例1
2.1 场景预设
你在做图像识别任务,分为3类
类别索引 实际含义 0 猫 🐱 1 狗 🐶 2 鸟 🐦
2.2 预测结果的混淆矩阵长什么样?
我们看一个实际例子矩阵(真实 vs 预测):
实际/预测 猫(0) 狗(1) 鸟(2) 猫(0) 50 10 0 狗(1) 5 40 5 鸟(2) 0 8 42
2.3 如何解读?
对角线的值(加粗的):
表示预测完全正确的数量
猫预测成猫 = 50,狗预测成狗 = 40,鸟预测成鸟 = 42
👉 对角线越大越好,说明模型越准确
非对角线的值(混淆错误):
猫预测成狗 = 10 → 猫被误判为狗 🐱→🐶
狗预测成猫 = 5,狗预测成鸟 = 5 → 狗被混淆了
鸟预测成狗 = 8 → 鸟也被误判为狗 🐦→🐶
2.4 例子:猫总被预测成狗怎么办?
你说得对:如果模型总把猫当狗,你会在这个格子看到大值:
[猫行][狗列] = 猫被预测成狗
用上面矩阵就是:
实际是猫,预测是狗 → 值是 10
🔥 说明问题:
模型对“猫”的特征学习不好,容易混淆成“狗”
或者猫和狗在数据中长得太像(比如都是短毛动物)
可以考虑:增加“猫”的数据量、做数据增强、加强“猫狗区分特征”的提取
2.5 行和 vs 列和的意义
每一行的总和 = 某类的真实样本数
比如第一行总和 50+10+0 = 60 → 实际有 60 张猫图
每一列的总和 = 预测成该类的样本数
比如“狗”列总和 10+40+8 = 58 → 有 58 张图被预测成狗
👉 你可以看到哪些类总是“被预测得太多”或“被预测得太少”,从而发现模型是否对某些类有偏向性
三、场景示例2
你做一个情感分类模型,分为:
正面 😃
中性 😐
负面 😠
现在模型结果混淆矩阵是:
实际 \ 预测 | 正面 😃 | 中性 😐 | 负面 😠 |
---|---|---|---|
正面 😃 | 50 | 30 | 20 |
中性 😐 | 10 | 70 | 20 |
负面 😠 | 5 | 10 | 85 |
我们能看到:
总体准确率 = (50+70+85)/300 ≈ 68.3%
正面被混成中性 + 负面:很多用户的高兴情绪没识别出来 → Recall 低
中性预测得还行,但也有不少误判 → 精确率、召回率都一般
负面识别得最好 → 模型对“生气”最敏感
四、混淆矩阵的作用
看模型“预测对了多少、错在哪儿”
就像一份“考试成绩单”,不仅告诉你考了多少分(整体准确率),还告诉你具体哪道题错了,错得有多严重。
作用类型 | 用途说明 |
---|---|
定位错误 | 哪类被误判最多?被谁误判最多? |
分析偏向 | 模型是不是更倾向预测某一类? |
样本不均衡分析 | 某些类预测差是不是因为样本太少? |
指导调优方向 | 哪些类需增强区分度或数据量? |
五、混淆矩阵的进阶作用
帮助你发现“模型的错误类型”和“模型的偏好”
它不只是“对 or 错”的统计,而是:
功能 | 解释 | 例子 |
---|---|---|
🎯 定位误判方向 | 看清哪些类被混淆 | 模型总把“猫”判成“狗” → 猫行狗列的值高 |
⚖️ 揭示类别偏向 | 哪类总是被预测太多或太少 | 所有类都被预测成“狗” → 狗列总和特别大 |
📊 分析样本问题 | 是否某类样本太少→模型表现差 | 鸟的预测全错,实际样本只有 10 个 |
🧠 指导改进方向 | 哪些类该补数据、调模型、增强特征 | 鸟总被混为猫 → 增加鸟的角度图、改模型结构 |
六、混淆矩阵是怎么得出来的?
6.1 举个最简单的例子:二分类任务
比如你在做“垃圾邮件检测”,模型要判断一封邮件是:
正类(1):垃圾邮件
负类(0):正常邮件
你有一个测试集,共有 10 条邮件,模型的预测结果和真实结果如下:
样本编号 | 实际标签 | 模型预测 |
---|---|---|
1 | 1 | 1 |
2 | 0 | 0 |
3 | 1 | 1 |
4 | 1 | 0 ❌(漏掉垃圾邮件) |
5 | 0 | 0 |
6 | 1 | 1 |
7 | 0 | 1 ❌(误判正常为垃圾) |
8 | 0 | 0 |
9 | 1 | 1 |
10 | 0 | 0 |
我们统计一下四个关键数量:
类别 | 定义 | 从上面例子中有哪些? | 数量 |
---|---|---|---|
TP(真正例) | 实际是垃圾,预测也是垃圾 | 样本 1、3、6、9 | 4 |
TN(真负例) | 实际是正常,预测也是正常 | 样本 2、5、8、10 | 4 |
FP(假正例) | 实际是正常,被误判为垃圾 | 样本 7 | 1 |
FN(假负例) | 实际是垃圾,被误判为正常 | 样本 4 | 1 |
6.2 得出:混淆矩阵(2x2)
实际 \ 预测 | 垃圾(1) | 正常(0) |
---|---|---|
垃圾(1) | 4 (TP) | 1 (FN) |
正常(0) | 1 (FP) | 4 (TN) |
七、 多分类混淆矩阵怎么得出?
比如你有 3 类情感标签:
0:正面
1:中性
2:负面
你测试集中有如下 5 条样本:
样本 | 实际标签 | 预测标签 |
---|---|---|
1 | 0 | 0 ✅ |
2 | 1 | 1 ✅ |
3 | 0 | 2 ❌ |
4 | 2 | 1 ❌ |
5 | 1 | 0 ❌ |
我们把实际标签作为“行”,预测标签作为“列”来构造混淆矩阵:
实际\预测 | 0 (正面) | 1 (中性) | 2 (负面) |
---|---|---|---|
0 (正面) | 1 | 0 | 1 |
1 (中性) | 1 | 1 | 0 |
2 (负面) | 0 | 1 | 0 |
解释:
对角线(预测正确)有两个:样本 1 和样本 2
样本 3 是“正面 → 预测为负面”,对应位置是第 0 行第 2 列
样本 4 是“负面 → 预测为中性”,对应位置是第 2 行第 1 列
样本 5 是“中性 → 预测为正面”,对应位置是第 1 行第 0 列
八、代码示例
8.1 安装依赖
pip install scikit-learn seaborn matplotlib
8.2 代码示例1
代码说明
这段代码的作用是:
计算混淆矩阵:也就是模型预测 vs 实际标签的“对错统计表”
y_true = [0, 1, 0, 2, 1] # 真实标签 y_pred = [0, 1, 2, 1, 0] # 模型预测
这两个列表按顺序一一对应,意思是:
样本编号 实际标签(y_true) 模型预测(y_pred) 是否预测正确 1 0 0 ✅ 正确 2 1 1 ✅ 正确 3 0 2 ❌ 错了 4 2 1 ❌ 错了 5 1 0 ❌ 错了 计算输出的混淆矩阵
cm
是:[[1 0 1] [1 1 0] [0 1 0]]
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 2, 1] # 真实标签
y_pred = [0, 1, 2, 1, 0] # 模型预测
cm = confusion_matrix(y_true, y_pred)
print(cm)
【运行结果】
[[1 0 1]
[1 1 0]
[0 1 0]]
输出会是一个 3x3 的二维数组(因为有 3 个分类标签 0/1/2)
【运行结果解读】
这是一个 3×3 的矩阵,表示的是:
横轴:预测值(Predicted)
纵轴:真实值(True)
实际\预测 0 1 2 0 1 0 1 1 1 1 0 2 0 1 0 🧠 解读每个数字的意思:
[0, 0] = 1 → 有 1 个样本是 0,预测对了(预测也是 0)
[0, 2] = 1 → 有 1 个样本是 0,预测错成 2
[1, 0] = 1 → 有 1 个样本是 1,预测错成 0
[1, 1] = 1 → 有 1 个样本是 1,预测对了
[2, 1] = 1 → 有 1 个样本是 2,预测错成 1
✅ 总结一下:
这段代码告诉你:
模型预测对了两条([0→0] 和 [1→1])
预测错了三条:
把 0 判成了 2
把 1 判成了 0
把 2 判成了 1
总体准确率:2 / 5 = 40%
8.3 代码示例2(可视化热力图)
代码说明
这段代码:
构造了一个混淆矩阵(
confusion_matrix(y_true, y_pred)
)
➤ 用于对比真实标签和模型预测,统计“预测对/错”情况。用 Seaborn 的热力图函数
sns.heatmap
可视化混淆矩阵
➤ 把抽象的数字统计图,变成更直观的“颜色格子图”。最终展示图中:
横轴是预测标签(Predicted)
纵轴是真实标签(True)
每个格子的数值表示:“实际是 A,被预测为 B 的样本数量”
【执行代码】
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 0]
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=[0,1,2], yticklabels=[0,1,2])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
【运行结果】
【输出结果解析】
【 输出的混淆矩阵是啥意思?】
给的标签是:
y_true = [0, 1, 0, 2, 1] y_pred = [0, 1, 2, 1, 0]
→ 用这两个列表生成的混淆矩阵是:
实际 \ 预测 0 1 2 0 1 0 1 1 1 1 0 2 0 1 0 解释如下:
[0, 0] = 1:有 1 个样本实际是 0,预测也是 0(预测对了 ✅)
[0, 2] = 1:有 1 个样本实际是 0,预测成了 2(预测错 ❌)
[1, 0] = 1:有 1 个样本实际是 1,预测成了 0(预测错 ❌)
[1, 1] = 1:有 1 个样本实际是 1,预测也是 1(预测对 ✅)
[2, 1] = 1:有 1 个样本实际是 2,预测成了 1(预测错 ❌)
【什么是热力图(Heatmap)?】
热力图是一种颜色编码矩阵数据的图表,常用于:
显示数值大小差异
强调某些数值高的区域
在
sns.heatmap
中:
每个格子的颜色深浅代表数值大小(预测数量多的颜色更深)
annot=True
会把每个格子的数值直接显示出来所以你能一眼看到哪些分类容易被预测错、哪些预测得多或少。
【结果得到了哪些元素?】
你得到了什么? 意义 confusion_matrix()
构造分类正确/错误的统计表 sns.heatmap()
把混淆矩阵“可视化成颜色图” 混淆矩阵热力图 更直观地看到“模型错误在哪”、“预测偏向哪一类”
✅ 总结一句话记住
混淆矩阵 = 模型“预测行为”的全景图,对角线越亮越好,非对角线越高说明“搞混了”。
准确率告诉你“分数”,混淆矩阵告诉你“哪道题错了”,让你有针对性地优化模型表现。