深度学习中的模型剪枝

发布于:2025-07-08 ⋅ 阅读:(27) ⋅ 点赞:(0)

      模型剪枝(model pruning)是一种模型压缩技术,是指从深度学习神经网络模型中移除不重要的参数,以减小模型规模并实现更高效的模型推理,非常适合在受限环境或实时应用中部署。通常,只剪枝参数的权重(weight),而偏差(bias)保持不变。

      模型剪枝主要有两种方法:训练时剪枝、训练后剪枝

      训练时剪枝(动态剪枝):指将剪枝过程直接集成到神经网络的训练阶段。在训练过程中,模型会以鼓励稀疏性(encourages sparsity)或移除不太重要的连接或神经元的方式进行训练,并将其作为优化过程的一部分。这意味着在训练迭代过程中,剪枝决策与权重更新同时进行。训练时剪枝可以通过诸如L1或L2正则化等正则化方法或将剪枝掩码(pruning masks)纳入优化过程来实现。

      训练后剪枝:指在训练模型完全训练后对其进行剪枝,而不考虑在训练过程中进行剪枝。当模型训练至收敛后,会应用剪枝技术来识别并移除训练模型中不太重要的连接、神经元或整个结构。这通常在训练完成后作为单独的步骤进行。主要分为两种类型:结构化剪枝和非结构化剪枝(structured and unstructured pruning)。非结构化剪枝通常侧重于移除单个模型权重参数,而结构化剪枝则侧重于移除整个权重结构

      (1).非结构化剪枝:是一种更简单、更原始的剪枝方法,但它易于上手,门槛较低。如下图所示:通用方法是根据原始权重本身或其激活函数设置最小阈值,以确定是否需要剪枝单个参数。如果参数未达到阈值,则将其设置为零。由于非结构化剪枝涉及将权重矩阵中的单个权重设置为零,这意味着模型剪枝之前的所有计算都需要执行,因此延迟(latency)改进极小。另一方面,它可以帮助降低模型权重的噪声,从而实现更一致的推理,并有助于减小模型大小,实现无损模型压缩。结构化剪枝在没有上下文信息和自适应的情况下肯定无法使用,而非结构化剪枝通常可以开箱即用,且不会带来太大风险。

      (2).结构化剪枝:是一种更注重架构的(architecturally)剪枝方法。如下图所示:通过移除整组结构化的权重,该方法减少了在模型权重图(model’s weights graph)的前向传播过程中必须进行的计算规模(scale of calculations)。这切实提升了模型推理速度和模型规模(model size)。

      训练后剪枝范围

      (1).局部剪枝:是指在神经网络层级内对单个神经元、连接(connections)或权重进行剪枝。它通常侧重于根据某些标准(例如权重幅值较低(low weight magnitude)、在特定层级中重要性较低或对模型性能贡献极小)移除不太重要的连接或神经元。局部剪枝通常涉及迭代技术,其中权重或连接会根据某些标准逐个或分小组进行剪枝。局部剪枝方法的示例包括权重幅值剪枝、基于单位幅值的剪枝或基于连接敏感度的剪枝。

      (2).全局剪枝:涉及同时剪枝整个神经元、层,甚至模型的大部分内容。它考虑的是整个网络中神经元或层的整体重要性,而不是关注单个层中的特定部分。全局剪枝通常涉及更复杂的技术,这些技术会考虑网络不同部分之间的相互作用和依赖关系。全局剪枝方法的示例包括迭代幅值剪枝(对整个网络的权重同时进行排序和剪枝)、最佳脑损伤算法或最佳脑外科手术算法(optimal brain damage, or optimal brain surgeon algorithms)。

      推理性能:虽然剪枝通常有利于模型压缩和加快推理速度,但过度剪枝可能会对模型造成不利影响。虽然剪枝的目的是将不重要的权重置零,但这些权重可能仍会对模型的决策过程产生轻微影响。较高的剪枝率也可能会无意中剪掉重要的权重,从而导致模型准确率下降。

      何时应该进行模型剪枝:在计算资源受限或效率至关重要的部署场景中,模型剪枝尤其有益。

      模型剪枝的主要优势

      (1).减少模型大小:通过剪枝不必要的权重或神经元,可以显著缩减模型的文件大小,使其更适合部署在存储空间和内存有限的设备上。压缩后的模型文件大小会随着剪枝量的增加而线性减小。这是可以预料的,因为在模型文件压缩后,置零的权重占用的空间几乎可以忽略不计。

      (2).更快的推理速度:经过剪枝的网络执行的计算更少,从而加快推理速度。因为归零权重只是简单的传递,不会增加模型的计算复杂度。

      (3).提高泛化能力:有时,剪枝可以通过删除冗余参数来帮助减少过拟合,从而潜在地提高模型在未知数据上的性能。

      微调:剪枝后,通常会对模型进行微调或重新训练,以恢复性能损失,并确保剪枝后的模型性能与原始模型相当。

      Pytorch中剪枝的使用:torch\nn\utils\prune.py

      (1).剪枝模块:要剪枝一个模块(module),首先从torch.nn.utils.prune中可用的剪枝技术中选择一种(或者通过继承BasePruningMethod实现你自己的剪枝方法)。然后,指定模块以及该模块中要修剪的参数名称。最后,使用所选剪枝技术所需的适当关键字参数,指定修剪参数。

      (2).迭代剪枝:模块中的同一参数可以多次剪枝,每次剪枝调用的效果等于连续应用的多个掩码(masks)的组合。

      (3).全局剪枝:使用torch.nn.utils.prune中的global_unstructured

      注:以上主要内容及截图来自:

      1. https://datature.io/blog

      2. https://docs.pytorch.org/tutorials

      GitHubhttps://github.com/fengbingchun/NN_Test