Segmentation Network: SegFormer

Published: 2025-07-12

Preface: I recently used the SegFormer network in a project. After preparing a classification dataset of roughly 4,000 images, the lightweight segformer_b1 model already produced good results: final metrics of mIoU 93.5, mPA 95.89, and Accuracy 98.78. The model is also small (best.pt is 52 MB unquantized, 15 MB after quantization) and inference is fast, so I decided to write up SegFormer here.
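
The post doesn't state which quantization method produced the 52 MB → 15 MB reduction; as a minimal sketch (not the author's actual pipeline), PyTorch dynamic quantization of the Linear layers would look like this — the checkpoint path and module set are assumptions:

import torch
import torch.nn as nn

# Hypothetical post-training dynamic quantization sketch; the actual
# method behind the 52MB -> 15MB reduction is not stated in this post.
model = torch.load("best.pt", map_location="cpu")  # assumed checkpoint path
model.eval()

# Quantize the Linear layers (which dominate SegFormer's parameter count)
# to int8; weights are stored quantized, activations stay float.
quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
torch.save(quantized.state_dict(), "best_int8.pt")
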

SegFormer project link: SegFormer on the Hugging Face machine-learning platform (demo/test code is available at the bottom of that page)

SegFormer paper link: [2105.15203] SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

SegFormer network architecture diagram

An outline of the SegFormer pipeline:

1. Given an input image of size H×W×3, overlapping patch embedding first divides it into 4×4 patches.

2. Encoder: these patches are fed into a hierarchical Transformer encoder (which uses Efficient Self-Attention) to obtain multi-level features at {1/4, 1/8, 1/16, 1/32} of the original resolution (a quick shape sketch follows this list).

3. Decoder: the multi-level features are passed to an all-MLP decoder that predicts the segmentation mask.
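
As an illustration of step 2 (my own sketch, not from the original post): the four stages downsample with patch-embedding strides 4, 2, 2, 2, which yields the {1/4, 1/8, 1/16, 1/32} pyramid:

# Feature-map resolutions produced by the hierarchical encoder for an
# H x W input; the per-stage patch-embedding strides are 4, 2, 2, 2.
H, W = 512, 512
scale = 1
for i, stride in enumerate([4, 2, 2, 2], start=1):
    scale *= stride
    print(f"Stage {i}: {H // scale} x {W // scale} (1/{scale} resolution)")
# Stage 1: 128 x 128 (1/4)   Stage 2: 64 x 64 (1/8)
# Stage 3: 32 x 32 (1/16)    Stage 4: 16 x 16 (1/32)
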

Main SegFormer modules

1. Encoder

Main role: extracts coarse-grained and fine-grained hierarchical multi-scale features.

class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(
            patch_size=patch_size,
            stride=stride,
            in_chans=in_channels,
            embed_dim=embed_dim
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio
            ) for _ in range(num_blocks)
        ])

        # Final normalization before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)

        # Run all Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Normalize
        x = self.norm(x)

        # Convert the sequence back to feature-map format [B, C, H, W]
        # (reshape, not view: the tensor is non-contiguous after permute)
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)

        return x
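
A quick shape check (my own snippet; it assumes the classes from the complete code at the end of this post are in scope, and uses the b0 stage-1 hyperparameters from that configuration):

import torch

# b0 stage 1: a 512x512 RGB image becomes a 1/4-resolution, 32-channel feature map.
stage1 = SegFormerStage(in_channels=3, embed_dim=32, num_blocks=2,
                        reduction_ratio=8, num_heads=1, expansion_ratio=8,
                        patch_size=7, stride=4)
out = stage1(torch.randn(2, 3, 512, 512))
print(out.shape)  # torch.Size([2, 32, 128, 128])
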
1.1 Overlap Patch Embeddings

① The input image is split with a convolution: patches of size patch_size are extracted while the kernel moves with stride stride, so neighboring patches overlap.

② Each patch is then flattened into a 1-D vector and passed through a normalization layer.

Tips: 1. The module outputs a tensor of shape (B, N, C), i.e. (batch size, number of tokens, embedding dimension).

      2. It also returns H and W, the spatial size of the embedded feature map, because the 2-D layout must be restored later during decoding.

class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # 7*7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
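
A minimal shape check (my own snippet, assuming the class above is in scope):

import torch

# With patch_size=7, stride=4, padding=3, a 512x512 input produces a
# 128x128 grid of overlapping patches, flattened to N = 16384 tokens.
embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=32)
tokens, H, W = embed(torch.randn(2, 3, 512, 512))
print(tokens.shape, H, W)  # torch.Size([2, 16384, 32]) 128 128
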
1.2 Transformer Block
1.2.1 Efficient Self-Attention

① Self-attention with an added sequence-reduction step that lowers the computational cost.

② The attention complexity drops from O(N²) to O(N²/R), because the key/value sequence length is reduced from N to N/R. For example, at stage 1 (N = 128×128 = 16384, R = 8), the keys and values shrink to 2048 tokens.

class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Sequence reduction: fold R consecutive tokens into the channel
        # dimension (N -> N/R, C -> C*R), then project C*R back down to C
        if reduction_ratio > 1:
            self.reduction = nn.Sequential(
                nn.Linear(dim * reduction_ratio, dim),
                nn.LayerNorm(dim)
            )

        # Attention projections; Key and Value share the reduced sequence
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]; N must be divisible by R

        # 1. Reduce the Key/Value sequence length from N to N/R
        if self.reduction_ratio > 1:
            kv_input = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
            kv_input = self.reduction(kv_input)  # [B, N/R, C]
        else:
            kv_input = x

        # 2. Generate Q/K/V, moving heads to dim 1 for batched matmul
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(kv_input).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]

        # 3. Scaled dot-product attention (complexity O(N^2/R))
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)

        return output
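
A shape check (my own snippet): with dim = 64 and R = 4, the key/value sequence is internally reduced from N = 4096 to 1024 while the output keeps the full query length:

import torch

attn = EfficientSelfAttention(dim=64, reduction_ratio=4, num_heads=2)
out = attn(torch.randn(2, 4096, 64))
print(out.shape)  # torch.Size([2, 4096, 64])
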
1.2.2 Mix-FFN

① Structure: a channel-expanding MLP (linear layer), a depthwise convolution that injects positional information, then a channel-compressing MLP.

② This replaces traditional positional encodings: the zero padding of a 3×3 depthwise convolution leaks enough location information, which avoids the accuracy drop that fixed positional encodings suffer when the test resolution differs from the training resolution.

class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)

        # 1. Channel-expanding MLP
        self.fc1 = nn.Linear(in_features, hidden_features)

        # 2. Depthwise convolution that injects positional information
        self.dwconv = nn.Conv2d(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_features  # depthwise convolution
        )

        # 3. Activation
        self.act = nn.GELU()

        # 4. Channel-compressing MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        # Recover the 2D layout; this assumes a square feature map
        # (true for square inputs; otherwise H and W would have to be passed in)
        H, W = int(N ** 0.5), int(N ** 0.5)

        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]

        # Reshape to 2D for the convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # depthwise conv leaks positional information via zero padding
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]

        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]

        return x
2. Decoder

Main role: a lightweight all-MLP decoder that directly fuses the multi-level features and predicts the semantic segmentation mask.

class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels

        # 1. Channel-alignment MLPs (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])

        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels
        )

        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes
        )

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned = self.align_mlps[i](feat)
            aligned_features.append(aligned)

        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []

        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(
                feat,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )
            upsampled_features.append(up_feat)

        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]

        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]

        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]

        return seg_mask
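
A usage sketch (my own snippet, using the b0 encoder channel list [32, 64, 160, 256] from the configuration further down):

import torch

decoder = SegFormerDecoder(in_channels_list=[32, 64, 160, 256],
                           unified_channels=256, num_classes=19)
# Pyramid features for a 512x512 input: 1/4, 1/8, 1/16, 1/32 resolution
feats = [torch.randn(2, 32, 128, 128), torch.randn(2, 64, 64, 64),
         torch.randn(2, 160, 32, 32), torch.randn(2, 256, 16, 16)]
print(decoder(feats).shape)  # torch.Size([2, 19, 128, 128])
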
2.1 MLP Layer

① The hierarchical multi-scale features are first channel-aligned and upsampled to a common resolution, then fused to combine semantic information from different resolutions.

2.2 MLP

① A final MLP produces the pixel-level classification result.

class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Input channels are 4*C (four feature maps concatenated)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)
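
For completeness, the SegmentationHead referenced by the decoder above (it also appears in the complete code below) is just a 1×1 convolution:

class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution for pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)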

Complete code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # 7*7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Sequence reduction: fold R consecutive tokens into the channel
        # dimension (N -> N/R, C -> C*R), then project C*R back down to C
        if reduction_ratio > 1:
            self.reduction = nn.Sequential(
                nn.Linear(dim * reduction_ratio, dim),
                nn.LayerNorm(dim)
            )

        # Attention projections; Key and Value share the reduced sequence
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]; N must be divisible by R

        # 1. Reduce the Key/Value sequence length from N to N/R
        if self.reduction_ratio > 1:
            kv_input = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
            kv_input = self.reduction(kv_input)  # [B, N/R, C]
        else:
            kv_input = x

        # 2. Generate Q/K/V, moving heads to dim 1 for batched matmul
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(kv_input).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]

        # 3. Scaled dot-product attention (complexity O(N^2/R))
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)

        return output


class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)

        # 1. Channel-expanding MLP
        self.fc1 = nn.Linear(in_features, hidden_features)

        # 2. Depthwise convolution that injects positional information
        self.dwconv = nn.Conv2d(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_features  # depthwise convolution
        )

        # 3. Activation
        self.act = nn.GELU()

        # 4. Channel-compressing MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        # Recover the 2D layout; this assumes a square feature map
        # (true for square inputs; otherwise H and W would have to be passed in)
        H, W = int(N ** 0.5), int(N ** 0.5)

        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]

        # Reshape to 2D for the convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # depthwise conv leaks positional information via zero padding
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]

        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]

        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads, expansion_ratio=4):
        super().__init__()
        # Normalization layers
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Attention and FFN
        self.attn = EfficientSelfAttention(dim, reduction_ratio, num_heads)
        self.mixffn = MixFFN(dim, expansion_ratio)

    def forward(self, x):
        # Residual connection 1: efficient self-attention
        x = x + self.attn(self.norm1(x))
        # Residual connection 2: Mix-FFN
        x = x + self.mixffn(self.norm2(x))
        return x


class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(
            patch_size=patch_size,
            stride=stride,
            in_chans=in_channels,
            embed_dim=embed_dim
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio
            ) for _ in range(num_blocks)
        ])

        # Final normalization before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)

        # Run all Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Normalize
        x = self.norm(x)

        # Convert the sequence back to feature-map format [B, C, H, W]
        # (reshape, not view: the tensor is non-contiguous after permute)
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)

        return x


class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Input channels are 4*C (four feature maps concatenated)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)


class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution for pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels

        # 1. Channel-alignment MLPs (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])

        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels
        )

        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes
        )

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned = self.align_mlps[i](feat)
            aligned_features.append(aligned)

        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []

        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(
                feat,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )
            upsampled_features.append(up_feat)

        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]

        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]

        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]

        return seg_mask


class SegFormer(nn.Module):
    def __init__(self, num_classes=3, version='b0'):
        super().__init__()

        # Select the configuration for the requested variant
        if version == 'b0':
            config = {
                'stages': [
                    # [in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride]
                    [3, 32, 2, 8, 1, 8, 7, 4],  # Stage 1
                    [32, 64, 2, 4, 2, 8, 3, 2],  # Stage 2
                    [64, 160, 2, 2, 5, 4, 3, 2],  # Stage 3
                    [160, 256, 2, 1, 8, 4, 3, 2]  # Stage 4
                ],
                'decoder_channels': 256
            }
        elif version == 'b1':
            config = {
                'stages': [
                    [3, 64, 2, 8, 1, 8, 7, 4],
                    [64, 128, 2, 4, 2, 8, 3, 2],
                    [128, 320, 2, 2, 5, 4, 3, 2],
                    [320, 512, 2, 1, 8, 4, 3, 2]
                ],
                'decoder_channels': 256
            }
        else:
            raise ValueError(f"Unsupported version: {version}")

        # Build the encoder stages
        self.stages = nn.ModuleList()
        in_channels_list = []  # input channel list for the decoder

        for stage_config in config['stages']:
            in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride = stage_config
            stage = SegFormerStage(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio,
                patch_size=patch_size,
                stride=stride
            )
            self.stages.append(stage)
            in_channels_list.append(embed_dim)

        # Build the decoder
        self.decoder = SegFormerDecoder(
            in_channels_list=in_channels_list,
            unified_channels=config['decoder_channels'],
            num_classes=num_classes
        )

    def forward(self, x):
        # Collect every stage's output for the decoder
        stage_outputs = []

        # Each stage consumes the previous stage's feature map
        # (the first stage consumes the raw image)
        for stage in self.stages:
            x = stage(x)
            stage_outputs.append(x)

        # Decode the multi-level features into a segmentation mask
        seg_mask = self.decoder(stage_outputs)

        # Upsample from 1/4 resolution back to the input resolution
        seg_mask = F.interpolate(seg_mask, scale_factor=4, mode='bilinear', align_corners=False)

        return seg_mask


# Quick test
if __name__ == "__main__":
    # Build the model
    model = SegFormer(num_classes=3, version='b0')
    print(f"Parameter count: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # Dummy input
    input_tensor = torch.randn(2, 3, 512, 512)  # [batch, channels, height, width]

    # Forward pass
    output = model(input_tensor)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")  # expected: [2, 3, 512, 512]

    # Simple sanity check on the output range
    print(f"Output min: {output.min().item():.4f}, max: {output.max().item():.4f}")

    # Optional: save a diagram of the model graph
    try:
        from torchviz import make_dot

        dot = make_dot(output, params=dict(model.named_parameters()))
        dot.render("segformer_model", format="png")
        print("Model graph saved as segformer_model.png")
    except ImportError:
        print("torchviz not installed, skipping model graph generation")

To wrap up, here is DeepSeek's review of the trained model's metrics (it made me laugh out loud):

