PyTorch实战(6)——模型微调详解

发布于:2025-09-04 ⋅ 阅读:(21) ⋅ 点赞:(0)

0. 前言

在本节中,我们将首先简要了解 AlexNet 的架构以及如何使用 PyTorch 构建一个 AlexNet 模型。然后,我们将探索 PyTorch 的预训练 CNN 模型库,最后使用一个预训练的 AlexNet 模型进行微调,应用于图像分类任务,并进行预测。

1. AlexNet

AlexNetLeNet 的继任者,其架构上进行了增量式的改进,例如使用了 8 层( 5 层卷积层和 3 层全连接层),模型参数从 60000 增加到 6000 万,并且使用了 MaxPool 代替 AvgPool。此外,AlexNet 的训练和测试使用了一个更大的数据集——ImageNet,其大小超过 100 GB,而 LeNet 则使用了 MNIST 数据集。AlexNet 的出现真正革新了卷积神经网络 (Convolutional Neural Network, CNN),使其成为在图像相关任务中,比其他经典机器学习模型(如支持向量机 SVM)更强大的模型。下图显示了 AlexNet 的架构:

AlexNet

如我们所见,AlexNet 的架构延续了 LeNet 的常见模式,采用了顺序堆叠的卷积层,并在输出端之前使用了一系列全连接层。PyTorch 使得将这样的模型架构转化为实际代码变得非常容易。使用 PyTorch 实现 AlexNet 架构:

class AlexNet(nn.Module):
    def __init__(self, number_of_classes=1000):
        super(AlexNet, self).__init__()
        self.feats = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.clf = nn.Linear(in_features=256, out_features=num_classes),

    def forward(self, inp):
        op = self.feats(inp)
        op = op.view(op.size(0), -1)
        op = self.clf(op)
        return op

其中 __init__ 函数包含了整个分层结构的初始化,包括卷积层、池化层和全连接层,并且使用了 ReLU 激活函数。forward 函数则是将数据点 x 传入已初始化的网络中。请注意,forward 方法的第二行已经执行了展平操作,因此我们不需要像 LeNet 那样单独定义该函数。
除了自己初始化模型架构并进行训练之外,PyTorch 通过其 torchvision 包提供了一个 models 子包,里面包含了用于解决不同任务的 CNN 模型定义,例如图像分类、语义分割、目标检测等,常用的图像分类任务模型包括:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3
  • GoogLeNet
  • ShuffleNet v2
  • MobileNet v2
  • ResNeXt
  • Wide ResNet
  • MnasNet
  • EfficientNet

在下一小节中,我们将使用一个预训练的 AlexNet 模型作为示例,展示如何使用 PyTorch 对其进行微调。

2. 使用 PyTorch 微调 AlexNet

接下来,我们使用 PyTorch 提供的预训练 CNN 模型库,首先下载并转换数据集。在本节中,我们将使用一个小型的蜜蜂和蚂蚁的图像数据集。数据集包含 240 张训练图像和 150 张验证图像,两个类别(蜜蜂和蚂蚁)的图像数量相等。从 Kaggle 下载数据集,并将解压后的文件存储在 hymenoptera_data 目录中。

(1) 导入所需库,加载并转换数据集:

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

ddir = 'hymenoptera_data'

data_transformers = {
    'train': transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(), 
                                    transforms.Normalize([0.490, 0.449, 0.411], [0.231, 0.221, 0.230])]),
    'val': transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), 
                                      transforms.ToTensor(), 
                                      transforms.Normalize([0.490, 0.449, 0.411], [0.231, 0.221, 0.230])])}

img_data = {k: datasets.ImageFolder(os.path.join(ddir, k), data_transformers[k]) for k in ['train', 'val']}
dloaders = {k: torch.utils.data.DataLoader(img_data[k], batch_size=8, shuffle=True) 
            for k in ['train', 'val']}
dset_sizes = {x: len(img_data[x]) for x in ['train', 'val']}
classes = img_data['train'].classes
print(classes)
dvc = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(2) 可视化训练数据集的样本图像:

def imageshow(img, text=None):
    img = img.numpy().transpose((1, 2, 0))
    avg = np.array([0.490, 0.449, 0.411])
    stddev = np.array([0.231, 0.221, 0.230])
    img = stddev * img + avg
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if text is not None:
        plt.title(text)

# Generate one train dataset batch
imgs, cls = next(iter(dloaders['train']))

# Generate a grid from batch
grid = torchvision.utils.make_grid(imgs)

imageshow(grid, text=[classes[c] for c in cls])

使用 NumPynp.clip() 方法,以确保图像像素值限制在 01 之间,从而使可视化更加清晰。输出如下所示:

样本可视化

(3) 定义微调流程,这本质上是在预训练模型上进行的训练过程:

def finetune_model(pretrained_model, loss_func, optim, epochs=10):
    start = time.time()

    model_weights = copy.deepcopy(pretrained_model.state_dict())
    accuracy = 0.0

    for e in range(epochs):
        print(f'Epoch number {e}/{epochs - 1}')
        print('=' * 20)

        # for each epoch we run through the training and validation set
        for dset in ['train', 'val']:
            if dset == 'train':
                pretrained_model.train()  # set model to train mode (i.e. trainbale weights)
            else:
                pretrained_model.eval()   # set model to validation mode

            loss = 0.0
            successes = 0

            # iterate over the (training/validation) data.
            for imgs, tgts in dloaders[dset]:
                imgs = imgs.to(dvc)
                tgts = tgts.to(dvc)
                optim.zero_grad()
                
                with torch.set_grad_enabled(dset == 'train'):
                    ops = pretrained_model(imgs)
                    _, preds = torch.max(ops, 1)
                    loss_curr = loss_func(ops, tgts)
                    # backward pass only if in training mode
                    if dset == 'train':
                        loss_curr.backward()
                        optim.step()

                loss += loss_curr.item() * imgs.size(0)
                successes += torch.sum(preds == tgts.data)

            loss_epoch = loss / dset_sizes[dset]
            accuracy_epoch = successes.double() / dset_sizes[dset]

            print(f'{dset} loss in this epoch: {loss_epoch}, accuracy in this epoch: {accuracy_epoch}')
            if dset == 'val' and accuracy_epoch > accuracy:
                accuracy = accuracy_epoch
                model_weights = copy.deepcopy(pretrained_model.state_dict())
        print()

    time_delta = time.time() - start
    print(f'Training finished in {time_delta // 60}mins {time_delta % 60}secs')
    print(f'Best validation set accuracy: {accuracy}')

    # load the best model version (weights)
    pretrained_model.load_state_dict(model_weights)
    return pretrained_model

finetune_model() 函数中,我们需要输入预训练模型(即模型的架构以及权重)、损失函数、优化器和训练epoch数。与从随机初始化的权重开始不同,微调过程中使用 AlexNet 的预训练权重开始。

(4) 在开始微调(训练)模型之前,定义一个辅助函数可视化模型的预测结果:

def visualize_predictions(pretrained_model, max_num_imgs=4):
    torch.manual_seed(1)
    was_model_training = pretrained_model.training
    pretrained_model.eval()
    imgs_counter = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (imgs, tgts) in enumerate(dloaders['val']):
            imgs = imgs.to(dvc)
            tgts = tgts.to(dvc)
            ops = pretrained_model(imgs)
            _, preds = torch.max(ops, 1)
            
            for j in range(imgs.size()[0]):
                imgs_counter += 1
                ax = plt.subplot(max_num_imgs//2, 2, imgs_counter)
                ax.axis('off')
                ax.set_title(f'pred: {classes[preds[j]]} || target: {classes[tgts[j]]}')
                imageshow(imgs.cpu().data[j])

                if imgs_counter == max_num_imgs:
                    pretrained_model.train(mode=was_model_training)
                    return
        pretrained_model.train(mode=was_model_training)

(5) 使用 PyTorchtorchvision.models 加载预训练的 AlexNet 模型:

model_finetune = models.alexnet(weights=torchvision.models.AlexNet_Weights.IMAGENET1K_V1).to(device=dvc)

该模型对象包含以下两个主要组成部分:

  • features:特征提取部分,包含所有的卷积层和池化层
  • classifier:分类器部分,包含所有的全连接层,最终连接到输出层

(6) 打印特征提取部分模型结构:

print(model_finetune.features)

输出结果如下内容:

模型架构

(7) 接下来,检查分类器部分:

print(model_finetune.classifier)

输出结果如下所示:

模型架构

可以看到,预训练模型的输出层大小为 1000,但我们的微调数据集只有 2 个类别。因此,需要对此进行调整:

model_finetune.classifier[6] = nn.Linear(4096, len(classes)).to(device=dvc)

(8) 定义优化器和损失函数,然后执行训练过程:

loss_func = nn.CrossEntropyLoss()
optim_finetune = optim.SGD(model_finetune.parameters(), lr=0.0001)

# train (fine-tune) and validate the model
model_finetune = finetune_model(model_finetune, loss_func, optim_finetune, epochs=10)

输出结果如下所示:

模型训练过程

可视化模型预测结果,以查看模型是否确实从数据集中学习到了相关特征:

visualize_predictions(model_finetune)

输出结果如下所示:

分类结果

可以看到,预训练的 AlexNet 模型能够在这个相对较小的图像分类数据集上进行迁移学习。这展示了迁移学习的强大能力,以及使用 PyTorch 微调经典模型的简便性。
在下一节中,我们将讨论 AlexNet 的一个更深、更复杂的卷积神经网络——VGG 网络。

3. 使用 PyTorch 微调 VGG

我们已经学习了 AlexNet 卷积神经网络架构。接下来,我们继续探索更复杂的 CNN 模型,但构建这些模型架构的核心原则是相同的。我们将采用模块化的模型构建方法,将卷积层、池化层和全连接层组合成块(模块),然后将这些块按顺序或分支的方式堆叠起来。在本节中,我们将介绍 VGGNet
VGG (Visual Geometry Group) 与拥有 8 层和 6000 万参数的 AlexNet 相比,VGG 包含 13 层( 10 个卷积层和 3 个全连接层)和 1.38 亿参数。VGG 基本上是在 AlexNet 架构的基础上堆叠了更多层,同时使用了更小的卷积核( 2x23x3)。
因此,VGG 的创新之处在于其架构带来的 VGG 之前未有的深度,VGG 架构如下所示:

VGG架构

上述的 VGG 架构称为 VGG13,因为它包含了 13 层。其他变体包括 VGG16VGG19,分别包含 16 层和 19 层。还有另一组变体——VGG13_bnVGG16_bnVGG19_bn,其中 bn 表示这些模型还包含批归一化 (batch normalization) 层。
PyTorchtorchvision.models 提供了预训练的 VGG 模型(包括前面讨论的六种变体),这些模型是在 ImageNet 数据集上训练的。接下来,我们使用预训练的 VGG13 模型对一个小型的蜜蜂和蚂蚁数据集进行预测。

(1) 首先,导入所需库:

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

(2) 下载数据,并设置蜜蜂和蚂蚁的数据集及数据加载器,同时定义数据预处理操作。为了对这些图像进行预测,需要下载 ImageNet 数据集的 1000 个类别标签。下载完成后,创建类别索引( 0999 )与对应类别标签之间的映射:

ddir = 'hymenoptera_data'
data_transformers = {
    'train': transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(), 
                                    transforms.Normalize([0.490, 0.449, 0.411], [0.231, 0.221, 0.230])]),
    'val': transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), 
                                      transforms.ToTensor(), 
                                      transforms.Normalize([0.490, 0.449, 0.411], [0.231, 0.221, 0.230])])}

img_data = {k: datasets.ImageFolder(os.path.join(ddir, k), data_transformers[k]) for k in ['train', 'val']}
dloaders = {k: torch.utils.data.DataLoader(img_data[k], batch_size=8, shuffle=True, num_workers=2) 
            for k in ['train', 'val']}
dset_sizes = {x: len(img_data[x]) for x in ['train', 'val']}
dvc = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import ast
with open('./imagenet1000_clsidx_to_labels.txt') as f:
    classes_data = f.read()
classes_dict = ast.literal_eval(classes_data)
print({k: classes_dict[k] for k in list(classes_dict)[:5]})

输出前五个类别的映射,如下所示:

{0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 3: 'tiger shark, Galeocerdo cuvieri', 4: 'hammerhead, hammerhead shark'}

(3) 定义模型预测可视化函数,函数接受预训练模型对象和要进行预测的图像数量,并输出带有预测结果的图像:

def imageshow(img, text=None):
    img = img.numpy().transpose((1, 2, 0))
    avg = np.array([0.490, 0.449, 0.411])
    stddev = np.array([0.231, 0.221, 0.230])
    img = stddev * img + avg
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if text is not None:
        plt.title(text)

def visualize_predictions(pretrained_model, max_num_imgs=4):
    was_model_training = pretrained_model.training
    pretrained_model.eval()
    imgs_counter = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (imgs, tgts) in enumerate(dloaders['val']):
            imgs = imgs.to(dvc)
            ops = pretrained_model(imgs)
            _, preds = torch.max(ops, 1)
            for j in range(imgs.size()[0]):
                imgs_counter += 1
                ax = plt.subplot(max_num_imgs//2, 2, imgs_counter)
                ax.axis('off')
                ax.set_title(f'pred: {classes_dict[int(preds[j])]}')
                imageshow(imgs.cpu().data[j])

                if imgs_counter == max_num_imgs:
                    pretrained_model.train(mode=was_model_training)
                    return
        pretrained_model.train(mode=was_model_training)

(4) 加载预训练的 VGG13 模型:

model = models.vgg13(weights=torchvision.models.VGG13_Weights.DEFAULT).to(device=dvc)

VGG13 模型大约需要 508 MB 的硬盘空间。

(5) 最后,我们使用预训练模型对蜜蜂和蚂蚁数据集进行预测:

visualize_predictions(model)

输出结果如下所示:

预测结果

经过训练的 VGG13 模型(尽管是在完全不同的数据集上训练的)能正确预测蜜蜂和蚂蚁数据集中的所有测试样本。我们可以看到,尽管模型是在与目标任务不同的数据集上训练的,但仍然能够从图像中提取相关的视觉特征。

小结

本文详细讲解了使用 PyTorch 微调经典 CNN 模型的方法。以 AlexNet 为例,介绍了其架构特点,并演示了完整的微调流程:数据加载与增强、模型结构调整、训练优化等,在蜜蜂/蚂蚁数据集上进行训练。进一步展示了 VGG13 的迁移学习能力,无需微调即可实现准确预测,验证了深度 CNN 的特征提取优势。

系列链接

PyTorch实战(1)——深度学习概述
PyTorch实战(2)——使用PyTorch构建神经网络
PyTorch实战(3)——PyTorch vs. TensorFlow详解
PyTorch实战(4)——卷积神经网络(Convolutional Neural Network,CNN)
PyTorch实战(5)——深度卷积神经网络