Python Check-in, Day 50

Published: 2025-06-13

Review of today's key points:

  1. Dissecting the ResNet structure
  2. Where to place CBAM modules
  3. Training strategies for pretrained models
    1. Differential learning rates (a minimal sketch follows this list)
    2. Three-stage fine-tuning
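
A minimal sketch of the differential-learning-rate idea (an illustrative fragment, not today's full script; the layer split and learning rates below are assumptions): give the pretrained backbone and the newly added classification head their own learning rates via optimizer parameter groups.

import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 10)  # new head for CIFAR-10

# Parameter groups: fine-tune the pretrained backbone gently, train the new head faster
backbone_params = [p for n, p in model.named_parameters() if not n.startswith("fc")]
optimizer = optim.Adam([
    {"params": backbone_params,       "lr": 1e-5},  # pretrained layers
    {"params": model.fc.parameters(), "lr": 1e-3},  # randomly initialized head
])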

PS: Today's code takes a while to train; on a 3080 Ti the full run needs roughly 40 minutes.

Homework:

  1. Work through the ResNet18 model structure until you understand it well.
  2. Try applying the fine-tuning strategy to VGG16 + CBAM (a minimal sketch is given at the end of this post).

The reference implementation below combines ResNet18 with CBAM and the three-stage fine-tuning strategy on CIFAR-10:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time

# Channel attention module (CBAM component)
class ChannelAttentionModule(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_features = self.mlp(self.avg_pool(x).view(b, c))
        max_features = self.mlp(self.max_pool(x).view(b, c))
        weights = self.sigmoid(avg_features + max_features).view(b, c, 1, 1)
        return x * weights

# Spatial attention module (CBAM component)
class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel, padding=kernel//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_features = torch.mean(x, dim=1, keepdim=True)
        max_features, _ = torch.max(x, dim=1, keepdim=True)
        combined = torch.cat([avg_features, max_features], dim=1)
        spatial_weights = self.sigmoid(self.conv(combined))
        return x * spatial_weights

# CBAM: channel attention followed by spatial attention
class CBAMBlock(nn.Module):
    def __init__(self, channels, reduction=16, kernel=7):
        super().__init__()
        self.channel_attention = ChannelAttentionModule(channels, reduction)
        self.spatial_attention = SpatialAttentionModule(kernel)

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

# Plotting configuration (SimHei kept from the original setup; only needed for CJK labels)
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False

# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data preprocessing and augmentation
train_augmentation = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10
cifar_train = datasets.CIFAR10('./data', train=True, download=True, transform=train_augmentation)
cifar_test = datasets.CIFAR10('./data', train=False, transform=test_transform)
train_loader = DataLoader(cifar_train, batch_size=64, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=64, shuffle=False)

# ResNet18 enhanced with CBAM blocks
class EnhancedResNet(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, reduction=16, kernel=7):
        super().__init__()
        # Load the pretrained ResNet18 backbone
        base_model = models.resnet18(pretrained=pretrained)
        
        # Adapt the stem for 32x32 inputs: 3x3 conv, stride 1, and drop the max-pool
        base_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        base_model.maxpool = nn.Identity()
        
        # CBAM blocks, one after each residual stage
        self.attention1 = CBAMBlock(64, reduction, kernel)
        self.attention2 = CBAMBlock(128, reduction, kernel)
        self.attention3 = CBAMBlock(256, reduction, kernel)
        self.attention4 = CBAMBlock(512, reduction, kernel)
        
        # Replace the classification head
        base_model.fc = nn.Linear(512, num_classes)
        self.base = base_model

    def forward(self, x):
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        
        # Residual stages interleaved with CBAM blocks
        x = self.base.layer1(x)
        x = self.attention1(x)
        
        x = self.base.layer2(x)
        x = self.attention2(x)
        
        x = self.base.layer3(x)
        x = self.attention3(x)
        
        x = self.base.layer4(x)
        x = self.attention4(x)
        
        # Global pooling and classification
        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        return self.base.fc(x)

# Freeze/unfreeze parameters and build the optimizer for each fine-tuning stage
def configure_optimizer(model, stage):
    if stage == 1:
        # Stage 1: freeze the backbone; train only the CBAM blocks and the classifier
        for param in model.parameters():
            param.requires_grad = False
        for name, param in model.named_parameters():
            if "attention" in name or "fc" in name:
                param.requires_grad = True
        return optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    elif stage == 2:
        # Stage 2: additionally unfreeze the high-level stages layer3/layer4 at a lower learning rate
        for name, param in model.named_parameters():
            if "layer3" in name or "layer4" in name or "attention" in name or "fc" in name:
                param.requires_grad = True
        return optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    else:  # stage 3
        # Stage 3: unfreeze everything for global fine-tuning at a very small learning rate
        for param in model.parameters():
            param.requires_grad = True
        return optim.Adam(model.parameters(), lr=1e-5)

# Training and evaluation loop with the three-stage fine-tuning schedule
def run_training(model, criterion, train_loader, test_loader, device, total_epochs):
    batch_losses = []
    epoch_losses = []
    train_acc_history = []
    test_acc_history = []
    
    optimizer = None
    
    for epoch in range(1, total_epochs + 1):
        start_time = time.time()
        
        # Switch the optimizer configuration at the stage boundaries
        if epoch == 1:
            print("\n" + "="*50 + "\nStage 1: train the attention modules and the classifier\n" + "="*50)
            optimizer = configure_optimizer(model, 1)
        elif epoch == 6:
            print("\n" + "="*50 + "\nStage 2: unfreeze the high-level convolutional stages\n" + "="*50)
            optimizer = configure_optimizer(model, 2)
        elif epoch == 21:
            print("\n" + "="*50 + "\nStage 3: global fine-tuning\n" + "="*50)
            optimizer = configure_optimizer(model, 3)
        
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total_samples = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # Track the loss
            current_loss = loss.item()
            batch_losses.append(current_loss)
            running_loss += current_loss
            
            # Accumulate accuracy statistics
            _, predicted = outputs.max(1)
            total_samples += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Print progress periodically
            if (batch_idx + 1) % 100 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                print(f'Epoch: {epoch}/{total_epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
                      f'| Current loss: {current_loss:.4f} | Avg loss: {avg_loss:.4f}')
        
        # Epoch-level training statistics
        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total_samples
        epoch_losses.append(train_loss)
        train_acc_history.append(train_acc)
        
        # Evaluation phase
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                test_loss += criterion(outputs, targets).item()
                
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()
        
        test_loss /= len(test_loader)
        test_acc = 100. * test_correct / test_total
        test_acc_history.append(test_acc)
        
        # Report epoch results
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch}/{total_epochs} finished | Time: {epoch_time:.2f}s | '
              f'Train acc: {train_acc:.2f}% | Test acc: {test_acc:.2f}%')
    
    # Visualize the training curves
    visualize_results(batch_losses, epoch_losses, train_acc_history, test_acc_history)
    return test_acc_history[-1]

# Plot batch losses, epoch losses, and accuracy curves
def visualize_results(batch_losses, epoch_losses, train_acc, test_acc):
    plt.figure(figsize=(15, 5))
    
    # Batch-level loss
    plt.subplot(1, 3, 1)
    plt.plot(batch_losses, 'b-', alpha=0.7)
    plt.xlabel('Training batch')
    plt.ylabel('Loss')
    plt.title('Per-batch training loss')
    plt.grid(True)
    
    # Epoch-level loss
    plt.subplot(1, 3, 2)
    plt.plot(epoch_losses, 'r-')
    plt.xlabel('Epoch')
    plt.ylabel('Average loss')
    plt.title('Per-epoch training loss')
    plt.grid(True)
    
    # Accuracy curves
    plt.subplot(1, 3, 3)
    plt.plot(train_acc, 'g-', label='Train accuracy')
    plt.plot(test_acc, 'b-', label='Test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Train and test accuracy')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()

# Main entry point
if __name__ == "__main__":
    # Initialize the model and loss function
    net = EnhancedResNet().to(device)
    loss_fn = nn.CrossEntropyLoss()

    print("Starting training of the CBAM-enhanced ResNet...")
    final_acc = run_training(net, loss_fn, train_loader, test_loader, device, 50)
    print(f"Training complete! Final test accuracy: {final_acc:.2f}%")

    # Save the trained weights
    torch.save(net.state_dict(), 'enhanced_resnet_cifar10.pth')
    print("Model saved to: enhanced_resnet_cifar10.pth")

 @浙大疏锦行