【机器学习&深度学习】混淆矩阵解读

发布于:2025-07-07 ⋅ 阅读:(21) ⋅ 点赞:(0)

目录

一、混淆矩阵到底是干嘛的?

二、场景示例1

2.1 场景预设

2.2 预测结果的混淆矩阵长什么样?

2.3 如何解读?

2.4 例子:猫总被预测成狗怎么办?

三、场景示例2 

四、混淆矩阵的作用

五、混淆矩阵的进阶作用

六、混淆矩阵是怎么得出来的?

6.1 举个最简单的例子:二分类任务

6.2 得出:混淆矩阵(2x2)

七、 多分类混淆矩阵怎么得出?

八、代码示例

8.1 安装依赖

8.2 代码示例1

8.3 代码示例2(可视化热力图)

✅ 总结一句话记住


一、混淆矩阵到底是干嘛的?

混淆矩阵就是把预测结果 vs 真实标签交叉对比,然后用一个矩阵(或热力图)显示每种预测结果的数量,让你能精确看到模型在哪些类别上预测对了,哪里搞错了


二、场景示例1

2.1 场景预设

你在做图像识别任务,分为3类

类别索引 实际含义
0 猫 🐱
1 狗 🐶
2 鸟 🐦

2.2 预测结果的混淆矩阵长什么样?

我们看一个实际例子矩阵(真实 vs 预测)

实际/预测 猫(0) 狗(1) 鸟(2)
猫(0) 50 10 0
狗(1) 5 40 5
鸟(2) 0 8 42

2.3 如何解读?

  • 对角线的值(加粗的)

    • 表示预测完全正确的数量

    • 猫预测成猫 = 50,狗预测成狗 = 40,鸟预测成鸟 = 42

    • 👉 对角线越大越好,说明模型越准确

  • 非对角线的值(混淆错误)

    • 猫预测成狗 = 10 → 猫被误判为狗 🐱→🐶

    • 狗预测成猫 = 5,狗预测成鸟 = 5 → 狗被混淆了

    • 鸟预测成狗 = 8 → 鸟也被误判为狗 🐦→🐶


2.4 例子:猫总被预测成狗怎么办?

 你说得对:如果模型总把猫当狗,你会在这个格子看到大值:

[猫行][狗列] = 猫被预测成狗

用上面矩阵就是:

实际是猫,预测是狗 → 值是 10

🔥 说明问题:

  • 模型对“猫”的特征学习不好,容易混淆成“狗”

  • 或者猫和狗在数据中长得太像(比如都是短毛动物)

  • 可以考虑:增加“猫”的数据量、做数据增强、加强“猫狗区分特征”的提取


2.5 行和 vs 列和的意义

  • 每一行的总和 = 某类的真实样本数

    • 比如第一行总和 50+10+0 = 60 → 实际有 60 张猫图

  • 每一列的总和 = 预测成该类的样本数

    • 比如“狗”列总和 10+40+8 = 58 → 有 58 张图被预测成狗

👉 你可以看到哪些类总是“被预测得太多”或“被预测得太少”,从而发现模型是否对某些类有偏向性


三、场景示例2 

你做一个情感分类模型,分为:

  • 正面 😃

  • 中性 😐

  • 负面 😠

现在模型结果混淆矩阵是:

实际 \ 预测 正面 😃 中性 😐 负面 😠
正面 😃 50 30 20
中性 😐 10 70 20
负面 😠 5 10 85

我们能看到:

  • 总体准确率 = (50+70+85)/300 ≈ 68.3%

  • 正面被混成中性 + 负面:很多用户的高兴情绪没识别出来 → Recall 低

  • 中性预测得还行,但也有不少误判 → 精确率、召回率都一般

  • 负面识别得最好 → 模型对“生气”最敏感


四、混淆矩阵的作用

看模型“预测对了多少、错在哪儿”

就像一份“考试成绩单”,不仅告诉你考了多少分(整体准确率),还告诉你具体哪道题错了,错得有多严重。

作用类型 用途说明
定位错误 哪类被误判最多?被谁误判最多?
分析偏向 模型是不是更倾向预测某一类?
样本不均衡分析 某些类预测差是不是因为样本太少?
指导调优方向 哪些类需增强区分度或数据量?

五、混淆矩阵的进阶作用

帮助你发现“模型的错误类型”和“模型的偏好”

它不只是“对 or 错”的统计,而是:

功能 解释 例子
🎯 定位误判方向 看清哪些类被混淆 模型总把“猫”判成“狗” → 猫行狗列的值高
⚖️ 揭示类别偏向 哪类总是被预测太多或太少 所有类都被预测成“狗” → 狗列总和特别大
📊 分析样本问题 是否某类样本太少→模型表现差 鸟的预测全错,实际样本只有 10 个
🧠 指导改进方向 哪些类该补数据、调模型、增强特征 鸟总被混为猫 → 增加鸟的角度图、改模型结构

六、混淆矩阵是怎么得出来的?

6.1 举个最简单的例子:二分类任务

比如你在做“垃圾邮件检测”,模型要判断一封邮件是:

  • 正类(1):垃圾邮件

  • 负类(0):正常邮件

你有一个测试集,共有 10 条邮件,模型的预测结果和真实结果如下:

样本编号 实际标签 模型预测
1 1 1
2 0 0
3 1 1
4 1 0 ❌(漏掉垃圾邮件)
5 0 0
6 1 1
7 0 1 ❌(误判正常为垃圾)
8 0 0
9 1 1
10 0 0

 我们统计一下四个关键数量:

类别 定义 从上面例子中有哪些? 数量
TP(真正例) 实际是垃圾,预测也是垃圾 样本 1、3、6、9 4
TN(真负例) 实际是正常,预测也是正常 样本 2、5、8、10 4
FP(假正例) 实际是正常,被误判为垃圾 样本 7 1
FN(假负例) 实际是垃圾,被误判为正常 样本 4 1

6.2 得出:混淆矩阵(2x2)

实际 \ 预测 垃圾(1) 正常(0)
垃圾(1) 4 (TP) 1 (FN)
正常(0) 1 (FP) 4 (TN)

七、 多分类混淆矩阵怎么得出?

比如你有 3 类情感标签:

  • 0:正面

  • 1:中性

  • 2:负面

你测试集中有如下 5 条样本:

样本 实际标签 预测标签
1 0 0 ✅
2 1 1 ✅
3 0 2 ❌
4 2 1 ❌
5 1 0 ❌

我们把实际标签作为“行”,预测标签作为“列”来构造混淆矩阵:

实际\预测 0 (正面) 1 (中性) 2 (负面)
0 (正面) 1 0 1
1 (中性) 1 1 0
2 (负面) 0 1 0

解释:

  • 对角线(预测正确)有两个:样本 1 和样本 2

  • 样本 3 是“正面 → 预测为负面”,对应位置是第 0 行第 2 列

  • 样本 4 是“负面 → 预测为中性”,对应位置是第 2 行第 1 列

  • 样本 5 是“中性 → 预测为正面”,对应位置是第 1 行第 0 列


八、代码示例

8.1 安装依赖

pip install scikit-learn seaborn matplotlib

8.2 代码示例1

代码说明


这段代码的作用是:

计算混淆矩阵:也就是模型预测 vs 实际标签的“对错统计表”

y_true = [0, 1, 0, 2, 1]  # 真实标签
y_pred = [0, 1, 2, 1, 0]  # 模型预测

这两个列表按顺序一一对应,意思是:

样本编号 实际标签(y_true) 模型预测(y_pred) 是否预测正确
1 0 0 ✅ 正确
2 1 1 ✅ 正确
3 0 2 ❌ 错了
4 2 1 ❌ 错了
5 1 0 ❌ 错了

计算输出的混淆矩阵 cm 是:

[[1 0 1]
 [1 1 0]
 [0 1 0]]

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 2, 1]  # 真实标签
y_pred = [0, 1, 2, 1, 0]  # 模型预测


cm = confusion_matrix(y_true, y_pred)
print(cm)

【运行结果】

[[1 0 1]
 [1 1 0]
 [0 1 0]]

输出会是一个 3x3 的二维数组(因为有 3 个分类标签 0/1/2)

【运行结果解读】

这是一个 3×3 的矩阵,表示的是:

  • 横轴:预测值(Predicted)

  • 纵轴:真实值(True)


实际\预测 0 1 2
0 1 0 1
1 1 1 0
2 0 1 0

🧠 解读每个数字的意思:

  • [0, 0] = 1 → 有 1 个样本是 0,预测对了(预测也是 0)

  • [0, 2] = 1 → 有 1 个样本是 0,预测错成 2

  • [1, 0] = 1 → 有 1 个样本是 1,预测错成 0

  • [1, 1] = 1 → 有 1 个样本是 1,预测对了

  • [2, 1] = 1 → 有 1 个样本是 2,预测错成 1


✅ 总结一下:

这段代码告诉你:

  • 模型预测对了两条([0→0] 和 [1→1])

  • 预测错了三条:

    • 把 0 判成了 2

    • 把 1 判成了 0

    • 把 2 判成了 1

  • 总体准确率:2 / 5 = 40%


8.3 代码示例2(可视化热力图)

代码说明


这段代码:

  1. 构造了一个混淆矩阵confusion_matrix(y_true, y_pred)
    ➤ 用于对比真实标签和模型预测,统计“预测对/错”情况。

  2. 用 Seaborn 的热力图函数 sns.heatmap 可视化混淆矩阵
    ➤ 把抽象的数字统计图,变成更直观的“颜色格子图”。

  3. 最终展示图中:

    • 横轴是预测标签(Predicted)

    • 纵轴是真实标签(True)

    • 每个格子的数值表示:“实际是 A,被预测为 B 的样本数量”

【执行代码】

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 0]

cm = confusion_matrix(y_true, y_pred)

sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=[0,1,2], yticklabels=[0,1,2])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

【运行结果】

【输出结果解析】

【 输出的混淆矩阵是啥意思?】

给的标签是:

y_true = [0, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 0]

→ 用这两个列表生成的混淆矩阵是:

实际 \ 预测 0 1 2
0 1 0 1
1 1 1 0
2 0 1 0

解释如下:

  • [0, 0] = 1:有 1 个样本实际是 0,预测也是 0(预测对了 ✅)

  • [0, 2] = 1:有 1 个样本实际是 0,预测成了 2(预测错 ❌)

  • [1, 0] = 1:有 1 个样本实际是 1,预测成了 0(预测错 ❌)

  • [1, 1] = 1:有 1 个样本实际是 1,预测也是 1(预测对 ✅)

  • [2, 1] = 1:有 1 个样本实际是 2,预测成了 1(预测错 ❌)


【什么是热力图(Heatmap)?】

热力图是一种颜色编码矩阵数据的图表,常用于:

  • 显示数值大小差异

  • 强调某些数值高的区域

sns.heatmap 中:

  • 每个格子的颜色深浅代表数值大小(预测数量多的颜色更深)

  • annot=True 会把每个格子的数值直接显示出来

所以你能一眼看到哪些分类容易被预测错、哪些预测得多或少


【结果得到了哪些元素?】

你得到了什么? 意义
confusion_matrix() 构造分类正确/错误的统计表
sns.heatmap() 把混淆矩阵“可视化成颜色图”
混淆矩阵热力图 更直观地看到“模型错误在哪”、“预测偏向哪一类”


✅ 总结一句话记住

混淆矩阵 = 模型“预测行为”的全景图,对角线越亮越好,非对角线越高说明“搞混了”。

准确率告诉你“分数”,混淆矩阵告诉你“哪道题错了”,让你有针对性地优化模型表现。