使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

发布于:2025-02-10 ⋅ 阅读:(104) ⋅ 点赞:(0)

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著


1. 小尺寸图像如何加剧样本不均衡?

(1) 局部裁剪导致类别分布偏差
  • 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
  • 示例
    • 原图中“道路”占比5%,若裁剪为 256x256 的小图,部分小图中可能完全无道路像素。
    • 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
(2) 批次内类别覆盖不足
  • 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
  • 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
  • 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
    • 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。

2. 样本不均衡的典型影响

  • 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
  • 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
  • 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。

3. 针对小尺寸图像的解决方案

(1) 数据层面的优化
  • 定向裁剪(Guided Cropping)
    • 根据类别分布优先裁剪包含稀有类别的小图。
    • 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
  • 过采样(Oversampling)
    • 对包含稀有类别的小图增加采样概率。
    • 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
  • 数据增强强化
    • 对小图中稀有类别区域进行针对性增强:
      • 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
      • 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
(2) 损失函数设计
  • 加权交叉熵(Weighted Cross-Entropy)
    • 根据类别像素频率反向加权,例如权重与类别频率成反比:
      weight = 1 / (class_freq + epsilon)  # 防止除零
      
  • Focal Loss
    • 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
      loss = -α * (1 - p)^γ * log(p)  # α平衡类别,γ聚焦难样本
      
  • Dice Loss / Tversky Loss
    • 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
      Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|)
      Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|)  # 调整α,β权衡假阳/假阴
      
(3) 模型架构改进
  • 上下文感知模块
    • 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
  • 多尺度特征融合
    • 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
  • 辅助监督(Auxiliary Supervision)
    • 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
  • 小批次大迭代
    • 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
  • 动态类别权重
    • 根据当前batch内的类别分布实时调整损失权重。
  • 困难样本挖掘(Hard Example Mining)
    • 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。

4. 实验验证建议

  • 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
  • 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
  • 消融实验:对比不同损失函数、数据增强策略的效果。

小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:

  • 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
  • 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。

在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块社区成熟库高效实现解决方案。


1. 数据层面:加权采样与增强

(1) 加权随机采样(WeightedRandomSampler)

PyTorch 内置 WeightedRandomSampler,可对包含稀有类别的图像补丁过采样:

import numpy as np

def compute_weight_for_patch(patch):
    image, mask = patch
    # 假设 mask 是一个二维数组,每个像素值表示类别标签
    # 计算每个类别的像素数量
    class_counts = np.bincount(mask.flatten())
    
    # 计算总像素数量
    total_pixels = mask.size
    
    # 计算每个类别的比例
    class_ratios = class_counts / total_pixels
    
    # 计算所有类别的权重
    class_weights = 1.0 / (class_ratios + 1e-6)  # 避免除以零,添加一个小的常数
    
    # 应用 sigmoid 函数
    class_weights = 1.0 / (1.0 + np.exp(-class_weights))
    
    # 计算样本的权重
    sample_weight = np.sum(class_weights)
    print("Total samples weights:", sample_weight)
    
    return class_weights

from torch.utils.data import WeightedRandomSampler

# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset]  # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)

Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:

import albumentations as A

transform = A.Compose([
    A.RandomCrop(256, 256),
    A.OneOf([
        A.RandomRotate90(),
        A.HorizontalFlip(),
        A.VerticalFlip()
    ]),
    A.RandomBrightnessContrast(p=0.5),
    # 对特定类别区域增强(如仅增强“车辆”区域)
    A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])

2. 损失函数:直接调用社区实现

(1) Focal Loss

使用 torchvision.ops 或第三方库:

# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_loss

loss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")

# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return loss.mean()
(2) Dice Loss

社区标准实现(或使用 segmentation_models_pytorch 库):

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        inputs = F.softmax(inputs, dim=1)
        targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)
        intersection = (inputs * targets).sum()
        union = inputs.sum() + targets.sum()
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice
(3) 直接调用 segmentation_models_pytorch 损失函数
import segmentation_models_pytorch as smp

loss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2])  # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True)   # 归一化版本

3. 模型层面:集成注意力与多尺度模块

(1) 使用预建模型库

segmentation_models_pytorch(SMP)提供即用的模型和模块:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=5,
    decoder_attention_type="scse",  # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)

直接使用 PyTorch 的 Conv2d 实现:

class DilatedConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dilation_rate=2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, 
            out_channels, 
            kernel_size=3, 
            padding=dilation_rate, 
            dilation=dilation_rate
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

# 在 U-Net 的 decoder 中插入空洞卷积块

4. 类别权重计算工具

(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight

# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(
    class_weight="balanced", 
    classes=np.arange(num_classes), 
    y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)

5. 完整 Pipeline 示例

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A

# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset)  # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")

# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):
    for images, masks in dataloader:
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

关键工具总结

问题类型 PyTorch 原生支持 推荐第三方库(直接调用)
数据采样 WeightedRandomSampler Albumentations(定向增强)
损失函数 自定义(需手写) segmentation_models_pytorch.losses
模型结构 手动添加模块(空洞卷积、注意力) segmentation_models_pytorch 预建模型
类别权重计算 sklearn.utils.class_weight 内置自动统计工具(如 SMP 数据集类)

注意事项

  1. 灵活组合策略:例如同时使用 WeightedRandomSamplerFocal Loss 可能过度偏向少数类,需通过实验调整。
  2. 监控类别指标:使用 torchmetrics 库计算每个类别的 IoU:
    from torchmetrics import JaccardIndex
    iou = JaccardIndex(num_classes=5, task="multiclass")
    iou.update(outputs, targets)
    print(f"IoU: {iou.compute()}")
    
  3. 混合精度训练:使用 torch.cuda.amp 加速训练,缓解显存压力:
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():
        outputs = model(images)
        loss = criterion(outputs, masks)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    

网站公告

今日签到

点亮在社区的每一天
去签到