大家好呀!今天咱们要来聊聊一个超厉害的技术——量化感知训练(Quantization-Aware Training,简称 QAT)
在神经网络的世界里,我们总是想方设法地让模型变得更准确、更高效,毕竟谁不想自己的模型在边缘设备上也能大展身手呢?不过,光靠改变架构、融合多层或者编译模型这些招数,有时候还是不够看的。为了打造又小又准的模型,研究人员们可是想了不少办法,主要有这么三种:
- 模型量化:这个招数的精髓就是把模型的权重从高精度(比如 16 位浮点数)变成低精度(比如 8 位整数),这样一来,模型的内存占用和计算需求就能大大减少啦。
- 模型剪枝:顾名思义,就是把训练好的神经网络里那些不太重要的神经元或者权重给去掉,让模型的结构变得更简单,而且还不怎么影响性能哦。
- 知识蒸馏(也叫“教师-学生训练”):想象一下,有个又大又复杂的模型(教师)特别厉害,它能把一些高级的知识传递给一个更小、更高效的模型(学生)。这个过程就像是教师把自己的“软标签”(也就是对不同类别相似性的高级理解)教给学生,让学生能够更好地泛化,而不是像教师那样只用那种尖锐的 one-hot 编码表示哦。
接下来,咱们就来深入了解一下模型量化到底是个啥,它有哪两种技术,还有怎么用 PyTorch 把它们实现出来。
量化到底是个啥玩意儿?
这张图展示了一个从 FP32 到 INT8 的代表性映射
量化可是优化神经网络的杀手锏之一呢!它的原理说起来也挺简单的,就是把神经网络里的模型数据(包括网络参数和激活值)从高精度的浮点数(通常是 16 位)转换成低精度的表示(通常是 8 位整数)。这么做的好处可多了去了,比如说:
- GPU 们现在可以用那种又快又省成本的 8 位核心(Nvidia GPU 的 Tensor Cores 就是这么干的)来计算卷积和矩阵乘法这些操作啦,这样一来,计算吞吐量就能蹭蹭往上涨咯。
- 有些层是带宽受限的(也就是内存受限),它们大部分时间都在忙着读写数据呢。对于这些层来说,减少计算时间并不能减少它们的总体运行时间。不过,要是减少了带宽需求,那可就大不一样啦!
- 模型的内存占用变小了,这意味着模型需要的存储空间少了,参数更新的体积也小了,缓存利用率还能提高呢,总之好处多多呀。
- 最后,把数据从内存搬到计算单元这个过程可是既耗时又耗能的哦。要是把精度从 16 位降到 8 位,那数据量就能减少一半,这样一来,就能节省不少能耗咯。
当然啦,把高精度数字映射到低精度数字的方法可不止一种,比如零点量化、绝对最大值(absmax)量化等等。不过,咱们今天就不深入这些细节啦。要是你对这些感兴趣,可以去瞅瞅 Hao Wu et al. 和 Amir Gholani et al. 写的那两篇技术论文哦。
量化的两种招式
这张图是我自己画的哦。
量化模型主要有两种方法:
1. 训练后量化(PTQ)
训练后量化(PTQ)是在模型训练完成之后才施展的一种招式。它不需要重新训练模型,而是通过一个小的校准数据集来确定最优的量化参数,从而把模型从高精度转换成低精度。这个过程主要是收集模型激活值的统计信息,然后计算合适的量化参数,尽量减少浮点表示和量化表示之间的差异。
PTQ 可以说是非常节省资源、实现和部署起来也很快呢,特别适合那些没办法重新训练模型的情况。不过,它的缺点也很明显哦,模型的准确性会有所下降,而且需要仔细地校准和调整参数,所以它更适合用来快速原型制作,而不是真正部署到实际场景中。
训练后量化又可以细分为两种小招式:
1.1)动态训练后量化
这个小招式的核心是在推理过程中,根据运行时输入模型的数据分布,实时地调整激活值的范围。
1.2)静态训练后量化
在这个方法里,会多出一个校准步骤。具体来说,就是用一个有代表性的数据集来估计激活值的范围,这个估计过程是在全精度下完成的,目的是尽量减少误差。等到估计完成之后,再把激活值缩放到低精度的数据类型中去。
2. 量化感知训练(QAT)
QAT 可是另一种大招哦!它是在模型训练过程中就模拟量化效果的一种方法。具体来说,就是在训练过程中引入“假量化”操作,这些操作能够模拟低精度对权重和激活值的影响。换句话说,就是带着量化约束来训练模型。模型在训练过程中会用到一种叫做**直通估计器(Straight-Through Estimator,简称 STE)**的技术来计算梯度,这样一来,模型就能更好地适应低精度的环境啦。
QAT 的好处在于,模型在训练过程中就能适应量化带来的噪声,所以最终的量化模型在准确性上通常会比 PTQ 的要好不少。不过,这也意味着 QAT 需要更多的计算资源和时间,因为要重新训练模型,而且实现起来也相对复杂一些。不过,要是你的模型对量化误差比较敏感,那 QAT 绝对是个不错的选择哦!
量化感知训练到底是怎么工作的呢?
这张图展示了 QAT 在 PyTorch 中的工作原理哦。
从前面的介绍里,咱们已经知道 QAT 的好处啦——和 PTQ 不同,QAT 是在训练过程中插入“假量化”模块的。这样一来,模型就能“看到”量化噪声,并且学会补偿这些噪声,最终得到的量化模型在准确性上就能和全精度模型相当接近啦。QAT 的工作流程大致如下:
- 准备阶段:把敏感层(比如卷积层、线性层、激活层等)替换成带有量化模拟功能的包装器。在 PyTorch 里,这可以通过
prepare_qat
或者prepare_qat_fx
来完成哦。 - 训练阶段:在每次前向传播过程中,权重和激活值都会被“假量化”——也就是像在 INT8/INT4 中那样进行四舍五入和限制范围。在反向传播过程中,会用到 STE,让梯度就像量化是恒等函数一样流动。
- 转换阶段:训练完成之后,用
convert
或者convert_fx
把假量化模块换成真正的量化内核。这样一来,模型就准备好进行高效的int8 / int4
推理啦。
假量化的数学原理
这张图展示了简单的量化过程。
咱们先不深入那些复杂的技巧,简单来说,量化过程是这样的:
假设 x_float
是一个实值激活值。均匀仿射量化使用以下公式:
scale = x max − x min q max − q min \text{scale} = \frac{x_{\text{max}} - x_{\text{min}}}{q_{\text{max}} - q_{\text{min}}} scale=qmax−qminxmax−xmin
zeroPt = round ( q min − x min scale ) \text{zeroPt} = \text{round}\left(\frac{q_{\text{min}} - x_{\text{min}}}{\text{scale}}\right) zeroPt=round(scaleqmin−xmin)
x q = clamp ( round ( x float scale ) + zeroPt , q min , q max ) x_q = \text{clamp}\left(\text{round}\left(\frac{x_{\text{float}}}{\text{scale}}\right) + \text{zeroPt}, q_{\text{min}}, q_{\text{max}}\right) xq=clamp(round(scalexfloat)+zeroPt,qmin,qmax)
x deq = ( x q − zeroPt ) × scale x_{\text{deq}} = (x_q - \text{zeroPt}) \times \text{scale} xdeq=(xq−zeroPt)×scale
在 QAT 中,假量化操作是这样的:
x fake = ( round ( x float scale ) + zeroPt − zeroPt ) × scale x_{\text{fake}} = \left(\text{round}\left(\frac{x_{\text{float}}}{\text{scale}}\right) + \text{zeroPt} - \text{zeroPt}\right) \times \text{scale} xfake=(round(scalexfloat)+zeroPt−zeroPt)×scale
所以,x_fake
仍然是浮点数,但它现在被限制在了 int8
张量会占据的相同离散格点上了。
梯度流动——直通估计器
这张图展示了 QAT 假量化操作在训练的前向传播(左边)和反向传播(右边)过程哦,图片来源是 这里。
四舍五入操作是不可微的,所以 PyTorch 用的是:
d L d x float ≈ d L d x fake \frac{dL}{dx_{\text{float}}} \approx \frac{dL}{dx_{\text{fake}}} dxfloatdL≈dxfakedL
在反向传播过程中,直接把假量化模块当作恒等函数来处理梯度,这样优化器就能调整上游的权重,以抵消量化噪声啦。
结果就是,学到的权重会自然地落在整数中心附近,而且调整后的 scale
和 zeroPt
能够最小化总的重构误差哦。
动手实践:用 PyTorch 实现量化
PyTorch 提供了三种不同的量化模式,接下来咱们就来一一看看吧。
1. Eager 模式量化
这是一个还在测试阶段的功能哦。用户需要手动进行融合,并且指定量化和反量化的具体位置。而且,这个模式只支持模块,不支持函数式操作。
下面的代码片段展示了从模型定义到 QAT 准备,再到最终的 int8
转换的每一步哦:
import os, torch, torch.nn as nn, torch.optim as optim
# 1. 定义模型,包含 QuantStub 和 DeQuantStub
class QATCNN(nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub() # 量化入口
self.conv1 = nn.Conv2d(1, 16, 3, padding=1) # 第一个卷积层
self.relu1 = nn.ReLU() # 第一个激活层
self.pool = nn.MaxPool2d(2) # 池化层
self.conv2 = nn.Conv2d(16, 32, 3, padding=1) # 第二个卷积层
self.relu2 = nn.ReLU() # 第二个激活层
self.fc = nn.Linear(32*14*14, 10) # 全连接层
self.dequant = torch.quantization.DeQuantStub()# 反量化出口
def forward(self, x):
x = self.quant(x) # 进入量化
x = self.pool(self.relu1(self.conv1(x))) # 第一个卷积块
x = self.relu2(self.conv2(x)) # 第二个卷积块
x = x.flatten(1) # 展平
x = self.fc(x) # 全连接层
return self.dequant(x) # 退出量化
# 2. QAT 准备
model = QATCNN()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # 设置量化配置
torch.quantization.prepare_qat(model, inplace=True) # 准备 QAT
# 3. 简单的训练循环
opt = optim.SGD(model.parameters(), lr=1e-2) # 优化器
crit = nn.CrossEntropyLoss() # 损失函数
for _ in range(3): # 训练 3 个 epoch
inp = torch.randn(16,1,28,28) # 随机输入
tgt = torch.randint(0,10,(16,)) # 随机目标
opt.zero_grad() # 清空梯度
crit(model(inp), tgt).backward() # 反向传播
opt.step() # 更新参数
# 4. 转换为真正的 int8 模型
model.eval() # 切换到评估模式
int8_model = torch.quantization.convert(model) # 转换模型
# 5. 查看存储节省
torch.save(model.state_dict(), "fp32.pth") # 保存 FP32 模型
torch.save(int8_model.state_dict(), "int8.pth") # 保存 INT8 模型
mb = lambda p: os.path.getsize(p)/1e6 # 计算文件大小(MB)
print(f"FP32: {mb('fp32.pth'):.2f} MB vs INT8: {mb('int8.pth'):.2f} MB") # 打印大小对比
预期输出:模型大小大约会减少 4 倍,而且在类似 MNIST 的数据上,准确率的损失不会超过 1% 哦。
为啥能行:torch.quantization.prepare_qat
会递归地为每个符合条件的层包装 FakeQuantize
模块,而默认的 FBGEMM
qconfig 会选择适合服务器/边缘 CPU 的每张量权重观察器和每通道激活观察器。
2. FX 图模式量化
这是一种自动化的量化工作流哦,目前还在维护阶段。它在 Eager 模式量化的基础上进行了升级,支持函数式操作,并且能够自动完成量化过程。不过,用户可能需要对自己的模型进行一些重构,以确保兼容性。
需要注意的是,由于可能存在符号追踪的问题,这个模式可能并不适用于所有模型。这也意味着,你可能需要对 torch.fx
有一定的了解哦。下面是一个使用这个模式的代码示例:
import torch, torchvision.models as models
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization import prepare_qat_fx, convert_fx
model = models.resnet18(weights=None) # 或者使用预训练模型
model.train()
# 一行代码搞定 qconfig 映射
qmap = get_default_qat_qconfig_mapping("fbgemm")
# 图重写
model_prepared = prepare_qat_fx(model, qmap)
# 微调几个 epoch
model_prepared.eval()
int8_resnet = convert_fx(model_prepared)
FX 模式是在图级别进行操作的哦:比如 conv2d
、batch_norm
和 relu
这些操作会被自动融合,从而生成更精简的内核,在 CPU 上的延迟也会更小哦。
3. PyTorch 2 导出量化
如果你打算把导出的程序部署到 C++ 运行时,那 PT2E(PyTorch 2 Export)就是你的不二之选咯。这是 PyTorch 2.1 中推出的一种全新的全图量化工作流,专门为通过 torch.export
捕获的模型设计的。整个流程只需要几行代码就能搞定哦:
import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_qat_pt2e, convert_pt2e)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer, get_symmetric_quantization_config)
class Tiny(nn.Module):
def __init__(self): super().__init__(); self.fc=nn.Linear(8,4)
def forward(self,x): return self.fc(x)
ex_in = (torch.randn(2,8),)
exported = torch.export.export_for_training(Tiny(), ex_in).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
qat_mod = prepare_qat_pt2e(exported, quantizer)
# 微调模型 ...
int8_mod = convert_pt2e(qat_mod)
torch.ao.quantization.move_exported_model_to_eval(int8_mod)
最终生成的图可以直接用于 torch::deploy
,或者提前编译到移动引擎中哦。
加餐福利:大型语言模型 Int4/Int8 混合演示
虽然这不是一个明确的 API,但 torchao
/torchtune
提供了一些用于极致压缩的原型量化器哦:
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
model = llama3(vocab_size=4096, num_layers=16,
num_heads=16, num_kv_heads=4,
embed_dim=2048, max_seq_len=2048).cuda()
qat_quant = Int8DynActInt4WeightQATQuantizer()
model = qat_quant.prepare(model).train()
# 开始微调(类似 Kathy 的微调方式)
optim = torch.optim.AdamW(model.parameters(), 1e-4)
lossf = torch.nn.CrossEntropyLoss()
for _ in range(100):
ids = torch.randint(0,4096,(2,128)).cuda()
label = torch.randint(0,4096,(2,128)).cuda()
loss = lossf(model(ids), label)
optim.zero_grad(); loss.backward(); optim.step()
model_quant = qat_quant.convert(model)
torch.save(model_quant.state_dict(),"llama3_int4int8.pth")
模型的激活值运行在 int8
上,权重运行在 int4
上,这样在单个 A100 GPU 上就能获得超过 2 倍的速度提升,同时内存占用减少约 60%,而且困惑度的下降不到 0.8 pp 哦。
如果你想了解更多关于使用 torchao
和 torchtune
进行 LLM 量化的信息,我强烈推荐你去读一读 PyTorch 的相关博客哦。
最佳实践
为了让模型在量化后能够最大程度地节省存储空间,同时又不损失太多的准确性,下面这些小贴士你可得记好了哦:
- 先用 PTQ 暖身:如果 PTQ 的准确性损失不到 2%,那么通常只需要进行短暂的 QAT 微调(5-10 个 epoch)就足够啦。
- 进行消融分析:通过消融分析来找出哪些层对量化比较敏感哦。如果量化某一层会导致性能大幅下降,那可以考虑保持该层的权重不变。
- 尽早融合操作:把
Conv + BN + ReLU
这些操作尽早融合起来,能够稳定观察器的范围,并且还能显著提高准确性呢。 - 冻结批量归一化统计量:在训练了几轮之后,调用
torch.ao.quantization.disable_observer
,然后冻结批量归一化统计量。这样可以防止范围出现振荡哦。 - 监控直方图:使用
torch.ao.quantization.get_observer_state_dict()
或者 Netron 来找出异常值哦。 - 调整学习率计划:使用较小的学习率(不超过 1e-3)可以避免在 STE 近似值起作用时出现过度调整的情况哦。
- 每通道权重量化:这比每张量量化要好一倍呢,它已经被设置为卷积操作的默认量化方式了。
- 混合精度:如果准确性还是下降了,可以考虑把第一层和最后一层保持在
fp16
,这样会更安全一些哦。 - 硬件检查:对于 x86 架构,使用
FBGEMM
;对于 ARM 架构,使用QNNPACK/XNNPACK
;选择与之匹配的 qconfig 哦。
总结
模型部署往往需要多管齐下的策略哦——打造出一个准确的模型其实只是第一步,真正难的是如何在大规模场景中进行部署。如果你的模型没办法承受 PTQ 带来的准确性损失,那 QAT 就是你的救星啦。不过,你得记住,在成功部署的过程中,需要考虑很多因素哦,包括目标平台以及它支持的操作等等。PyTorch 的成熟 QAT 工具链在这个时候就派上大用场啦,它能让量化任何模型都变得轻而易举——不管是简单的 CNN,还是拥有十亿参数的语言模型呢。