掌握Pytorch模型 压缩 裁剪与量化

发布于:2023-01-20 ⋅ 阅读:(312) ⋅ 点赞:(0)

在深度学习模型的搭建和部署中,我们需要考虑到模型的权重个数、模型权重大小、模型推理速度和计算量。本文将分享在Pytorch中进行模型压缩、裁剪和量化的教程。

权重压缩

模型在训练时使用的模型权重类型为float32,而在模型部署时则不需要高的数据精度。可以将类型转换为float16进行保存,这样可以降低45%左右的权重大小。

  • 步骤1:训练并保存模型
import timm
model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
  • 步骤2:转换数据类型并存储
params = torch.load('model_mobilevit_xxs.pth') # float32
for key in params.keys():
    params[key] = params[key].half() # float16

torch.save(params, 'model_mobilevit_xxs_half.pth')

权重裁剪

在模型训练完成后可以考虑对冗余的权重进行裁剪,有以下几种裁剪方法:

  • 按照比例随机裁剪
  • 按照权重大小裁剪

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

使用的案例代码如下:

import torch.nn.utils.prune as prune
import numpy as np

model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))

# 选中需要裁剪的层
module = model.head.fc

# random_unstructured裁剪
prune.random_unstructured(module, name="weight", amount=0.3)

# l1_unstructured裁剪
prune.l1_unstructured(module, name="weight", amount=0.3)

# ln_structured裁剪
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

在使用权重裁剪需要注意:

  • 权重裁剪并不会改变模型的权重大小,只是增加了稀疏性;
  • 权重裁剪并不会减少模型的预测速度,只是减少了计算量;
  • 权重裁剪的参数比例会对模型精度有影响,需要测试和验证;

权重量化

32-bit的乘加变成了8-bit的乘加,模型权重大小减少,对内存的要求降低了。

https://pytorch.org/docs/stable/quantization.html

Eager Mode Quantization

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc1 = torch.nn.Linear(100, 40)
        self.fc2 = torch.nn.Linear(1000, 400)

    def forward(self, x):
        x = self.fc1(x)
        return x

# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')

Post Training Static Quantization

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(100, 10)
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

model_fp32.eval()

model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')

Pytorch暂时的量化操作还不是很完善,可能存在只能在CPU上运行,且速度变慢的情况。如果有量化需求,推荐使用tensorrt和GPU一起使用。

本文含有隐藏内容,请 开通VIP 后查看