python打卡第47天

发布于:2025-06-08 ⋅ 阅读:(13) ⋅ 点赞:(0)

昨天代码中注意力热图的部分顺移至今天

知识点回顾:

热力图

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()  
        
        # ---------------------- 第一个卷积块 ----------------------
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        # 新增:插入通道注意力模块(SE模块)
        self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16)  
        self.pool1 = nn.MaxPool2d(2, 2)  
        
        # ---------------------- 第二个卷积块 ----------------------
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        # 新增:插入通道注意力模块(SE模块)
        self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16)  
        self.pool2 = nn.MaxPool2d(2)  
        
        # ---------------------- 第三个卷积块 ----------------------
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        # 新增:插入通道注意力模块(SE模块)
        self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16)  
        self.pool3 = nn.MaxPool2d(2)  
        
        # ---------------------- 全连接层(分类器) ----------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        # ---------- 卷积块1处理 ----------
        x = self.conv1(x)       
        x = self.bn1(x)         
        x = self.relu1(x)       
        x = self.ca1(x)  # 应用通道注意力
        x = self.pool1(x)       
        
        # ---------- 卷积块2处理 ----------
        x = self.conv2(x)       
        x = self.bn2(x)         
        x = self.relu2(x)       
        x = self.ca2(x)  # 应用通道注意力
        x = self.pool2(x)       
        
        # ---------- 卷积块3处理 ----------
        x = self.conv3(x)       
        x = self.bn3(x)         
        x = self.relu3(x)       
        x = self.ca3(x)  # 应用通道注意力
        x = self.pool3(x)       
        
        # ---------- 展平与全连接层 ----------
        x = x.view(-1, 128 * 4 * 4)  
        x = self.fc1(x)           
        x = self.relu3(x)         
        x = self.dropout(x)       
        x = self.fc2(x)           
        
        return x  

# 重新初始化模型,包含通道注意力模块
model = CNN()
model = model.to(device)  # 将模型移至GPU(如果可用)

criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器

# 引入学习率调度器,在训练过程中动态调整学习率--训练初期使用较大的 LR 快速降低损失,训练后期使用较小的 LR 更精细地逼近全局最优解。
# 在每个 epoch 结束后,需要手动调用调度器来更新学习率,可以在训练过程中调用 scheduler.step()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,        # 指定要控制的优化器(这里是Adam)
    mode='min',       # 监测的指标是"最小化"(如损失函数)
    patience=3,       # 如果连续3个epoch指标没有改善,才降低LR
    factor=0.5        # 降低LR的比例(新LR = 旧LR × 0.5)
# 训练模型(复用原有的train函数)
print("开始训练带通道注意力的CNN模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):
    """可视化模型的注意力热力图,展示模型关注的图像区域"""
    model.eval()  # 设置为评估模式
    
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            if i >= num_samples:  # 只可视化前几个样本
                break
                
            images, labels = images.to(device), labels.to(device)
            
            # 创建一个钩子,捕获中间特征图
            activation_maps = []
            
            def hook(module, input, output):
                activation_maps.append(output.cpu())
            
            # 为最后一个卷积层注册钩子(获取特征图)
            hook_handle = model.conv3.register_forward_hook(hook)
            
            # 前向传播,触发钩子
            outputs = model(images)
            
            # 移除钩子
            hook_handle.remove()
            
            # 获取预测结果
            _, predicted = torch.max(outputs, 1)
            
            # 获取原始图像
            img = images[0].cpu().permute(1, 2, 0).numpy()
            # 反标准化处理
            img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)
            img = np.clip(img, 0, 1)
            
            # 获取激活图(最后一个卷积层的输出)
            feature_map = activation_maps[0][0].cpu()  # 取第一个样本
            
            # 计算通道注意力权重(使用SE模块的全局平均池化)
            channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]
            
            # 按权重对通道排序
            sorted_indices = torch.argsort(channel_weights, descending=True)
            
            # 创建子图
            fig, axes = plt.subplots(1, 4, figsize=(16, 4))
            
            # 显示原始图像
            axes[0].imshow(img)
            axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')
            axes[0].axis('off')
            
            # 显示前3个最活跃通道的热力图
            for j in range(3):
                channel_idx = sorted_indices[j]
                # 获取对应通道的特征图
                channel_map = feature_map[channel_idx].numpy()
                # 归一化到[0,1]
                channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)
                
                # 调整热力图大小以匹配原始图像
                from scipy.ndimage import zoom
                heatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))
                
                # 显示热力图
                axes[j+1].imshow(img)
                axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')
                axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')
                axes[j+1].axis('off')
            
            plt.tight_layout()
            plt.show()

# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

这个注意力热图是通过构子机制: `register_forward_hook` 捕获最后一个卷积层(`conv3`)的输出特征图。  

  1. **通道权重计算**:对特征图的每个通道进行全局平均池化,得到通道重要性权重。  

  2. **热力图生成**:将高权重通道的特征图缩放至原始图像尺寸,与原图叠加显示。


 

热力图(红色表示高关注,蓝色表示低关注)半透明覆盖在原图上。主要从以下方面理解:

- **高关注区域**(红色):模型认为对分类最重要的区域。  

  例如:  

  - 在识别“狗”时,热力图可能聚焦狗的面部、身体轮廓或特征性纹理。  

  - 若热力图错误聚焦背景(如红色区域在无关物体上),可能表示模型过拟合或训练不足。

**多通道对比**

- **不同通道关注不同特征**:  

  例如:  

  - 通道1可能关注整体轮廓,通道2关注纹理细节,通道3关注颜色分布。  

  - 结合多个通道的热力图,可全面理解模型的决策逻辑。

可以帮助解释

   - 检查模型是否关注正确区域(如识别狗时,是否聚焦狗而非背景)。  

   - 发现数据标注问题(如标签错误、图像噪声)。

   - 向非技术人员解释模型决策依据(如“模型认为这是狗,因为关注了眼睛和嘴巴”)。