知识蒸馏基础知识

发布于:2025-02-23 ⋅ 阅读:(90) ⋅ 点赞:(0)

参考笔记:

YOLOv5改进系列(二十五) 知识蒸馏理论与实践_yolov5知识蒸馏-CSDN博客

全网最细图解知识蒸馏(涉及知识点:知识蒸馏实现代码,知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)-CSDN博客

学习视频:【精读AI论文】知识蒸馏_哔哩哔哩_bilibili 


目录

1. 什么是知识蒸馏

2.轻量化网络的方式有哪些

3.为什么要知识蒸馏

4.知识蒸馏的理论依据

5.知识蒸馏分类

5.1 目标蒸馏-Logits方法

5.1.1 Hard-targets、Soft-targets

5.1.2 蒸馏温度T

5.1.3 蒸馏温度T的特点

5.2 特征蒸馏方法

6.知识蒸馏过程

7.图解知识蒸馏


 本文主讲的是目标蒸馏-Logits方法 

1. 什么是知识蒸馏

知识蒸馏

知识蒸馏就是把一个大的教师模型的知识萃取出来,把它浓缩到一个小的学生模型,可以理解为一个大的教师神经网络把他的知识教给小的学生网络,这里有一个知识的迁移过程,从教师网络迁移到了学生网络身上,教师网络一般比较臃肿,所以教师网络把知识教给学生网络,学生网络是一个比较小的网络,这样就可以用学生网络去做一些轻量化网络做的事情

2.轻量化网络的方式有哪些

(1)压缩已训练好的模型:知识蒸馏、权值量化、权重剪枝、通道剪枝、注意力迁移

(2)直接训练轻量化网络:SqueezeNet、MobileNetv1v2v3、MnasNet、ShuffleNet、EfficientNet、EfficientDet

(3)加速卷积运算:im2col + GEMM、Wiongrad、低秩分解

(4)硬件部署:TensorRT、Jetson、TensorFlow-lite、Openvino、FPGA集成电路

3.为什么要知识蒸馏

深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,这样小的模型就适用于部署在终端设备上了

Student Model 部署在终端设备上 

(1)提升模型精度

如果对目前的网络模型 A 的精度不是很满意,那么可以先训练一个更高精度的 teacher 模型 B(通常参数量更多,时延更大),然后用这个训练好的 teacher 模型 Bstudent 模型 A 进行知识蒸馏,得到一个更高精度的 A 模型。

(2)降低模型时延,压缩模型参数

如果对目前的网络模型 A 的时延不满意,可以先找到一个时延更低,参数量更小的模型 B ,通常来讲, B 模型精度也会比较低,然后通过训练一个更高精度的 teacher 模型 C 来对这个参数量小的模型 B 进行知识蒸馏,使得该模型 B 的精度接近最原始的模型 A,从而达到降低时延的目的。

(3)标签之间的域迁移

假如使用狗和猫的数据集训练了一个 teacher 模型 A ,使用香蕉和苹果训练了一个 teacher 模型 B ,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移

4.知识蒸馏的理论依据

知识蒸馏使用的是 Teacher—Student 模型,其中 Teacher 是“知识”的输出者, Student 是“知识”的接受者。知识蒸馏的过程分为 3 个阶段:

(1)Teacher 模型训练:简称为 Net-T ,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对 Teacher 模型不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入 X , 其都能输出 Y ,其中 Y 经过 softmax 的映射,输出值是对应类别的概率值
(2)Student 模型训练: 简称为 Net-S ,它是参数量较小、模型结构相对简单的模型。同样的,对于输入 X ,其都能输出 YY 经过 softmax 映射后同样能输出对应类别的概率值

(3)知识蒸馏训练: Net-T 学习能力强,可以将它学到的知识迁移给学习能力相对弱的 Net-S 模型,以此来增强 Net-T 模型的泛化能力。复杂笨重但是效果好的 Net-T 模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的 Net-S 小模型

5.知识蒸馏分类

知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为 Soft-target 蒸馏或 Logits 方法蒸馏)和基于特征蒸馏的算法两个大的方向

5.1 目标蒸馏-Logits方法

分类问题的共同点是模型最后会有一个 softmax 层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的 Teacher 模型,我们利用 Teacher 模型来蒸馏训练 Student 模型时,可以直接让 Student 模型去学习 Teacher 模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用 Tearcher 模型的 softmax 层输出的类别概率来作为 Soft-targets Soft-targets 作为训练 Student 模型的标签

5.1.1 Hard-targets、Soft-targets

传统的模型训练过程:采用 Hard-targets ,类别标签对应的 One-Hot 向量,除了正确类别为 1,其他类别都是 0

知识蒸馏训练过程:用 Teacher Modelclass probabilities 作为 Student ModelSoft-targets


下面看一些例子

假设有 3 个类别,分别是 [马,驴子,汽车]

 对于马这张图片,Hard-targetsSoft-targets 两种形式的类别标签如下:

Hard-targets、Soft-targets 

使用 Hard-targets 来训练网络,对于马这张图片,相当于告诉网络这就是一匹马,不是驴和汽车,并且不是驴和汽车的程度是相等的,因为 Hard-targets 中驴和汽车的标签值都为 0

但通过肉眼其实能观察到马和驴其实是有一点相似的,马更像驴子而更不像汽车。Hard-targets 一个致命的缺陷就是秉持绝对的 “正确” ,对于非正确类别不提供任何额外的信息

使用 Soft-targets 来训练网络,对于马这种图片,相当于告诉网络是马的概率为 0.7 ,是汽车的概率为 0.05 ,说明马和汽车是非常不像的,而是驴的概率为 0.25,说明马和驴是有一点类似的。同理,驴和汽车也是非常不像的。显然 Soft-targets 传递的信息更多

其他例子 


综上:

(1)在使用 Soft-targets 训练时,Student Model 可以很快学习到 Teacher Model 的推理过程

(2)传统的 Hard-targets 的训练方式,所有的负标签都会被平等对待。而 Soft-targets 包含了更多的“知识”和“信息”,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)。所以我们可以用 Teacher 网络的预测结果作为训练 Student 网络的标签 

5.1.2 蒸馏温度T

如果觉得 Teacher ModelSoft-targets 还不够 Soft ,即想把其他非正确类别的概率也变大,把它们的相对大小充分暴露出来,让学生网络能够有一个强烈的信号,知道这些非正确类别的更多信息,因此提出了蒸馏温度 TT 越大,Soft-targets 就越 Soft

实现方法:在原始的 Softmax 公式加 TT = 1 时,就是常规的 Softmax 操作

计算公式 

 下面我们来看一下 T = 1T = 3 时的对比例子

不同温度的Softmax对比 

(1)T = 1 时,普通的 Softmax 操作, 4 个概率值的两级分化十分严重,猫和狗的预测概率非常小,接近于 0  。而马的预测概率非常高,达到 0.88 。 

(2)T = 3 时,使用蒸馏温度,4 个概率值相比较于 T = 1 的概率值两级分化情况没那么严重,变得更 Soft 了,而且各类别概率值的大小顺序和 T=1 时是一样的

5.1.3 蒸馏温度T的特点

  • 原始的 Softmax 函数是 T = 1 时的特例;T<1 时,概率分布比原始更“陡峭”,即当 0<T<1 时,Softmax 的输出值会接近与 Hard-targetsT>1 时,概率分布比原始更“平缓”
  • 随着 T 的增加,Softmax 的输出分布会越来越平缓。温度越高,Softmax 上各个值的分布就越平均,极端情况的 Softmax 值是平均分布的,如下图 T = 100

  • 温度的高低改变的是 Student 模型训练过程中对负标签的关注程度。温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;温度较高时,负标签相关的值会相对增大, Student 模型会相对更多地关注到负标签
  • 针对较困难的分类或检测任务, Temperature 通常取 1 ,确保 Teacher Model 中正确预测的贡献
  • 温度的选取需要进行实际实验的比较,本质上就是如下两种情况作取舍:
    • 当想从负标签中多学到一些信息的时候,温度应调高一些
    • 当想减少负标签的干扰的时候,温度应调低一些

总的来说,温度的选择和 Student 模型的大小有关,Student 模型参数量比较小时,相对比较低的温度就可以了。因为参数量小的模型不能学到所有 Teacher 模型的知识,所以可以适当忽略掉 

5.2 特征蒸馏方法

另外一种知识蒸馏思路是特征蒸馏方法,它不像 Logits 方法那样 Student 只学习 TeacherLogits 这种结果知识,而是学习 Teacher 网络结构中的中间层特征

最早采用这种模式的工作来自于论文《FITNETS:Hints for Thin Deep Nets》,它强迫 Student 某些中间层的网络响应,去逼近 Teacher 对应的中间层的网络响应。这种情况下, Teacher 中间特征层的响应,就是传递给 Student 的知识。在此之后,提出了各种新方法,但是大致思路还是这个思路,本质是 Teacher 将特征级知识迁移给 Student

6.知识蒸馏过程

知识蒸馏过程

损失函数计算 

上图中的 distillation lossstudent loss 也可以称为 soft losshard loss,两者都可以采用传统的交叉熵损失函数,交叉熵损失函数的计算公式如下:

交叉熵损失函数计算公式 

下图是损失函数计算的例子

图解损失函数计算过程 

实际的损失函数设计中,会给 soft losshard loss 增加一个权重系数,构成完成的损失函数,即:

Total loss 设计为 Soft LossHard Loss 所对应的交叉熵的加权和,其中 \color{red}\alpha 越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小 Soft Loss 的比重,让真实标签帮助鉴别困难样本

7.图解知识蒸馏


网站公告

今日签到

点亮在社区的每一天
去签到