Neural Networks - Day 49

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
 
# ====================== Configuration & device check ======================
# Matplotlib font setup (SimHei covers Chinese text in figures)
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with SimHei

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
 
 
# ====================== Data preprocessing & loading ======================
# Training set: data augmentation plus normalization
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
 
# Test set: normalization only (no augmentation at evaluation time)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
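# Note: the Normalize mean/std values above are the per-channel statistics
# commonly used for CIFAR-10; train and test transforms must share the same values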
 
# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True, 
    transform=train_transform
)
test_dataset = datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True,  # safe no-op if the archive is already present
    transform=test_transform
)
 
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
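 
# # Optional sanity check (not in the original post): uncomment to confirm that
# # one batch has the expected shape (64, 3, 32, 32) and labels fall in [0, 9]
# images, labels = next(iter(train_loader))
# print(f"batch shape: {images.shape}, labels: {labels.min().item()}..{labels.max().item()}")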
 
 
# ====================== CBAM module definitions ======================
## Channel attention module
class ChannelAttention(nn.Module):
    def __init__(self, in_channels: int, ratio: int = 16):
        """
        通道注意力机制
        Args:
            in_channels: 输入通道数
            ratio: 降维比例,默认16
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # 全局最大池化
        
        # 共享全连接层实现通道降维和升维
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()  # 生成通道权重
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播
        Args:
            x: 输入特征图 (B, C, H, W)
        Returns:
            通道加权后的特征图
        """
        b, c, h, w = x.shape
        avg_feat = self.fc(self.avg_pool(x).view(b, c))  # 平均池化特征
        max_feat = self.fc(self.max_pool(x).view(b, c))  # 最大池化特征
        attn = self.sigmoid(avg_feat + max_feat).view(b, c, 1, 1)  # 融合权重
        return x * attn  # 应用通道注意力
 
 
## Spatial attention module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        """
        空间注意力机制
        Args:
            kernel_size: 卷积核尺寸,默认7
        """
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播
        Args:
            x: 输入特征图 (B, C, H, W)
        Returns:
            空间加权后的特征图
        """
        # 通道维度池化
        avg_feat = torch.mean(x, dim=1, keepdim=True)  # 平均池化
        max_feat, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化
        pool_feat = torch.cat([avg_feat, max_feat], dim=1)  # 拼接特征
        
        attn = self.conv(pool_feat)  # 卷积提取空间特征
        return x * self.sigmoid(attn)  # 应用空间注意力
 
 
## Combined CBAM module
class CBAM(nn.Module):
    def __init__(self, in_channels: int, ratio: int = 16, kernel_size: int = 7):
        """
        卷积块注意力模块 (CBAM)
        Args:
            in_channels: 输入通道数
            ratio: 通道注意力降维比例,默认16
            kernel_size: 空间注意力卷积核尺寸,默认7
        """
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播(先通道注意力,后空间注意力)
        Args:
            x: 输入特征图 (B, C, H, W)
        Returns:
            注意力加权后的特征图
        """
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x
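 
# # Quick shape check (illustrative, not in the original post): CBAM only
# # re-weights features, so the output shape must match the input shape
# _cbam = CBAM(in_channels=64)
# assert _cbam(torch.randn(2, 64, 16, 16)).shape == (2, 64, 16, 16)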
 
 
# ====================== CNN model with CBAM ======================
class CBAM_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Conv block 1: 3 -> 32 channels, followed by CBAM
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.cbam1 = CBAM(in_channels=32)  # first CBAM module
        
        # Conv block 2: 32 -> 64 channels, followed by CBAM
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.cbam2 = CBAM(in_channels=64)  # second CBAM module
        
        # Conv block 3: 64 -> 128 channels, followed by CBAM
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.cbam3 = CBAM(in_channels=128)  # third CBAM module
        
        # Fully connected classifier head
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10)
        )
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播流程
        Args:
            x: 输入图像 (B, 3, 32, 32)
        Returns:
            分类 logits (B, 10)
        """
        # Conv block 1 + CBAM 1
        x = self.conv_block1(x)
        x = self.cbam1(x)
        
        # Conv block 2 + CBAM 2
        x = self.conv_block2(x)
        x = self.cbam2(x)
        
        # Conv block 3 + CBAM 3
        x = self.conv_block3(x)
        x = self.cbam3(x)
        
        # Flatten and pass through the classifier head
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x
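 
# # Illustrative check (not in the original post): three 2x2 max-pools shrink a
# # 32x32 input to 4x4, which matches the 128 * 4 * 4 input size of fc_layers
# assert CBAM_CNN()(torch.randn(1, 3, 32, 32)).shape == (1, 10)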
 
 
# ====================== Training configuration & functions ======================
# Initialize model, loss function, and optimizer
model = CBAM_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
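# ReduceLROnPlateau halves the learning rate (factor=0.5) once the monitored
# metric stops improving for 3 consecutive epochs (patience=3); it is stepped
# with the epoch-level test loss inside train() below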
 
 
def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    device: torch.device,
    epochs: int
) -> float:
    """
    训练过程主函数
    Args:
        model: 待训练模型
        train_loader: 训练数据加载器
        test_loader: 测试数据加载器
        criterion: 损失函数
        optimizer: 优化器
        scheduler: 学习率调度器
        device: 计算设备
        epochs: 训练轮数
    Returns:
        最终测试准确率
    """
    model.train()
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    all_iter_losses = []
    iter_indices = []
 
    for epoch in range(epochs):
        model.train()  # reset to training mode each epoch (eval mode is set during testing below)
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        # Training phase
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
 
            # Record per-iteration loss
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            
            running_loss += iter_loss
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()
 
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                print(f"Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} "
                      f"| Batch loss: {iter_loss:.4f} | Avg loss: {avg_loss:.4f}")
 
        # Epoch-level training metrics
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)
 
        # Evaluation phase
        model.eval()
        test_loss = 0.0
        correct_test = 0
        total_test = 0
 
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
 
        # Epoch-level test metrics
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)
 
        # Adjust the learning rate based on the test loss
        scheduler.step(epoch_test_loss)
 
        # Epoch summary
        print(f"Epoch {epoch+1}/{epochs} done | "
              f"Train Acc: {epoch_train_acc:.2f}% | Test Acc: {epoch_test_acc:.2f}%")
 
    # Plot training curves
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
 
    return epoch_test_acc
 
 
# ====================== Plotting functions ======================
def plot_iter_losses(losses: list, indices: list) -> None:
    """Plot the loss for every training iteration."""
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration (batch index)')
    plt.ylabel('Loss')
    plt.title('Per-batch Training Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
 
 
def plot_epoch_metrics(
    train_acc: list,
    test_acc: list,
    train_loss: list,
    test_loss: list
) -> None:
    """绘制 epoch 级准确率和损失曲线"""
    epochs = range(1, len(train_acc) + 1)
    plt.figure(figsize=(12, 4))
 
    # Accuracy subplot
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='Train Accuracy')
    plt.plot(epochs, test_acc, 'r-', label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Train vs. Test Accuracy')
    plt.legend()
    plt.grid(True)

    # Loss subplot
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='Train Loss')
    plt.plot(epochs, test_loss, 'r-', label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train vs. Test Loss')
    plt.legend()
    plt.grid(True)
 
    plt.tight_layout()
    plt.show()
 
 
# ====================== Run training ======================
epochs = 50
print("Training the CNN with CBAM...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"Training complete! Final test accuracy: {final_accuracy:.2f}%")
 
# # To save the model, uncomment the lines below
# torch.save(model.state_dict(), 'cifar10_cbam_cnn_model.pth')
# print("Model saved as: cifar10_cbam_cnn_model.pth")
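 
# # A matching reload sketch (assumes the checkpoint above was saved):
# model.load_state_dict(torch.load('cifar10_cbam_cnn_model.pth', map_location=device))
# model.eval()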

@浙大疏锦行

