Preface: I recently used the SegFormer network in a project. After preparing the data, a dataset of roughly 4,000 labeled images, the lightweight segformer_b1 model already gave good results: final metrics of mIoU 93.5, mPA 95.89, and Accuracy 98.78. The model is also small (best.pt is 52 MB unquantized, 15 MB after quantization) and inference is fast, so I decided to write up SegFormer here.
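The quantization step itself is not shown in this post; below is a minimal sketch of one common approach, post-training dynamic INT8 quantization of the Linear layers in PyTorch. It is an assumption that this matches the pipeline behind the 15 MB figure; the checkpoint name best.pt is taken from above, and loading a full pickled model assumes it was saved that way.

import torch
import torch.nn as nn

# Assumes best.pt stores a full pickled model (not just a state_dict)
model = torch.load("best.pt", map_location="cpu", weights_only=False)
model.eval()
# Dynamically quantize all Linear layers to INT8 (one common recipe,
# not necessarily the exact one used for the numbers above)
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
torch.save(quantized, "best_int8.pt")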
SegFormer project link: SegFormer - Hugging Face (test demo code is also included at the very bottom of this post)
SegFormer paper: [2105.15203] SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
SegFormer network architecture diagram (see the overview figure in the paper)
A brief walkthrough of the SegFormer pipeline (a minimal pretrained-inference sketch follows the list):
1. Given an image of size H×W×3, it is first divided into 4×4 patches by overlapping patch embedding.
2. The encoder feeds these patches through a hierarchical Transformer encoder (which introduces Efficient Self-Attention) to obtain multi-level features at {1/4, 1/8, 1/16, 1/32} of the original resolution.
3. The decoder passes these multi-level features into an all-MLP head to predict the segmentation mask.
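To try SegFormer without building the network by hand, the Hugging Face page linked above provides pretrained checkpoints. A minimal inference sketch, assuming the transformers library and one of NVIDIA's published checkpoints (not the model trained in this post; example.jpg is a hypothetical input):

from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = Image.open("example.jpg")
inputs = processor(images=image, return_tensors="pt")
logits = model(**inputs).logits  # [1, num_labels, H/4, W/4]
pred = logits.argmax(dim=1)      # per-pixel class ids at 1/4 resolution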
SegFormer main modules
1. Encoder
Role: extracts hierarchical multi-scale features at both coarse and fine granularity.
class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(
            patch_size=patch_size,
            stride=stride,
            in_chans=in_channels,
            embed_dim=embed_dim
        )
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio
            ) for _ in range(num_blocks)
        ])
        # Final norm before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Normalize
        x = self.norm(x)
        # Convert the sequence back to a feature map: [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        return x
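A quick shape check of one encoder stage, using the b0 Stage 1 settings from the config further below. The sizes are illustrative, and the snippet runs once the OverlapPatchEmbed and TransformerBlock classes from the following subsections are defined:

import torch

stage = SegFormerStage(in_channels=3, embed_dim=32, num_blocks=2,
                       reduction_ratio=8, num_heads=1, expansion_ratio=8,
                       patch_size=7, stride=4)
x = torch.randn(2, 3, 512, 512)
print(stage(x).shape)  # torch.Size([2, 32, 128, 128]), i.e. 1/4 resolution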
1.1 Overlap Patch Embeddings
① The input image is split by a convolution: blocks of size patch_size are extracted with stride `stride`, so that neighboring patches overlap.
② Each patch is then flattened into a 1-D vector and normalized by a LayerNorm.
Tips: 1. The module outputs a tensor of shape (B, N, C), i.e. (batch size, number of patches, embedding dimension).
2. It also returns H and W, the spatial size of the embedded feature map, which is needed later to reshape the token sequence back to 2-D.
class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, C], N = H*W
        x = self.norm(x)
        return x, H, W
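To make the tips concrete, a quick check with illustrative sizes: a 3×64×64 input with patch_size=7 and stride=4 gives a 16×16 grid of overlapping patches, i.e. N = 256 tokens.

embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=32)
x = torch.randn(2, 3, 64, 64)
tokens, H, W = embed(x)
print(tokens.shape, H, W)  # torch.Size([2, 256, 32]) 16 16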
1.2 Transformer Block
1.2.1 Efficient Self-Attention
① Self-attention is introduced, together with a sequence-reduction layer that lowers the computational complexity.
② The complexity drops from O(N²) to O(N²/R); concretely, the key/value sequence length is reduced from N to N/R.
class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Sequence reduction as in the paper: Reshape(N/R, C*R) followed by
        # Linear(C*R, C). Requires the token count N to be divisible by R.
        self.sr = nn.Linear(dim * reduction_ratio, dim)
        self.sr_norm = nn.LayerNorm(dim)
        # Attention projections; Key and Value share the reduced sequence
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]
        # 1. Reduce the K/V sequence length: [B, N, C] -> [B, N/R, C]
        x_ = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
        x_ = self.sr_norm(self.sr(x_))
        # 2. Generate Q/K/V and split heads: [B, num_heads, seq, head_dim]
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]
        # 3. Scaled dot-product attention (complexity O(N^2/R))
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return output
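A quick sanity check with illustrative sizes: with N = 1024 tokens and R = 4, the K/V sequence shrinks to 256 tokens, so the attention matrix is 1024×256 instead of 1024×1024, which is exactly the N²/R saving claimed above.

esa = EfficientSelfAttention(dim=64, reduction_ratio=4, num_heads=2)
x = torch.randn(2, 1024, 64)
print(esa(x).shape)  # torch.Size([2, 1024, 64])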
1.2.2 Mix-FFN
① A channel-expanding MLP (linear layer), a depthwise convolution that injects positional information, then a channel-compressing MLP.
② This replaces traditional positional encodings: the zero-padded depthwise convolution leaks location information, avoiding the accuracy drop that otherwise occurs when the test resolution differs from the training resolution.
class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expanding MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depthwise convolution that injects positional information
        self.dwconv = nn.Conv2d(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_features  # depthwise convolution
        )
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compressing MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # assumes a square token grid
        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]
        # Back to 2-D for the convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # depthwise conv leaks location information
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]
        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x
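Note that the forward pass above assumes a square token grid (H = W = √N). A quick check with N = 16×16 = 256 tokens:

ffn = MixFFN(in_features=64, expansion_ratio=4)
x = torch.randn(2, 256, 64)
print(ffn(x).shape)  # torch.Size([2, 256, 64])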
2. Decoder
Role: a lightweight all-MLP decoder that directly fuses the multi-level features and predicts the semantic segmentation mask.
class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels
        )
        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes
        )

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned = self.align_mlps[i](feat)
            aligned_features.append(aligned)
        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(
                feat,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )
            upsampled_features.append(up_feat)
        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]
        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]
        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask
2.1 MLP Layer
① The hierarchical multi-scale features from the encoder are upsampled to a unified resolution, and the semantic information from the different resolutions is then fused.
2.2 MLP
① A final MLP produces the pixel-level classification result (its definition, SegmentationHead, is shown right after the two helpers below).
class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer, but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The input channel count is 4*C (four concatenated feature maps)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)
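For completeness, here is the SegmentationHead referenced by the decoder above (it also appears in the complete code below):

class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution for pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

With all three helpers in place, the decoder can be checked on dummy encoder outputs at strides {4, 8, 16, 32}; the sizes are illustrative, using the b0 channel widths from the config below:

decoder = SegFormerDecoder(in_channels_list=[32, 64, 160, 256], num_classes=3)
feats = [
    torch.randn(2, 32, 128, 128),   # 1/4 resolution
    torch.randn(2, 64, 64, 64),     # 1/8
    torch.randn(2, 160, 32, 32),    # 1/16
    torch.randn(2, 256, 16, 16),    # 1/32
]
print(decoder(feats).shape)  # torch.Size([2, 3, 128, 128])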
Complete code:
import torch
import torch.nn as nn
import torch.nn.functional as F
class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, C], N = H*W
        x = self.norm(x)
        return x, H, W
class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Sequence reduction as in the paper: Reshape(N/R, C*R) followed by
        # Linear(C*R, C). Requires the token count N to be divisible by R.
        self.sr = nn.Linear(dim * reduction_ratio, dim)
        self.sr_norm = nn.LayerNorm(dim)
        # Attention projections; Key and Value share the reduced sequence
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]
        # 1. Reduce the K/V sequence length: [B, N, C] -> [B, N/R, C]
        x_ = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
        x_ = self.sr_norm(self.sr(x_))
        # 2. Generate Q/K/V and split heads: [B, num_heads, seq, head_dim]
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]
        # 3. Scaled dot-product attention (complexity O(N^2/R))
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return output
class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expanding MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depthwise convolution that injects positional information
        self.dwconv = nn.Conv2d(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_features  # depthwise convolution
        )
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compressing MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # assumes a square token grid
        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]
        # Back to 2-D for the convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # depthwise conv leaks location information
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]
        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x
class TransformerBlock(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads, expansion_ratio=4):
        super().__init__()
        # Normalization layers
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Attention and FFN
        self.attn = EfficientSelfAttention(dim, reduction_ratio, num_heads)
        self.mixffn = MixFFN(dim, expansion_ratio)

    def forward(self, x):
        # Residual connection 1: ESA
        x = x + self.attn(self.norm1(x))
        # Residual connection 2: Mix-FFN
        x = x + self.mixffn(self.norm2(x))
        return x
class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(
            patch_size=patch_size,
            stride=stride,
            in_chans=in_channels,
            embed_dim=embed_dim
        )
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio
            ) for _ in range(num_blocks)
        ])
        # Final norm before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Normalize
        x = self.norm(x)
        # Convert the sequence back to a feature map: [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        return x
class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer, but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The input channel count is 4*C (four concatenated feature maps)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)

class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution for pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels
        )
        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes
        )

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned = self.align_mlps[i](feat)
            aligned_features.append(aligned)
        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(
                feat,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )
            upsampled_features.append(up_feat)
        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]
        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]
        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask
class SegFormer(nn.Module):
    def __init__(self, num_classes=3, version='b0'):
        super().__init__()
        # Select the configuration for the requested variant
        if version == 'b0':
            config = {
                'stages': [
                    # [in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride]
                    [3, 32, 2, 8, 1, 8, 7, 4],    # Stage 1
                    [32, 64, 2, 4, 2, 8, 3, 2],   # Stage 2
                    [64, 160, 2, 2, 5, 4, 3, 2],  # Stage 3
                    [160, 256, 2, 1, 8, 4, 3, 2]  # Stage 4
                ],
                'decoder_channels': 256
            }
        elif version == 'b1':
            config = {
                'stages': [
                    [3, 64, 2, 8, 1, 8, 7, 4],
                    [64, 128, 2, 4, 2, 8, 3, 2],
                    [128, 320, 2, 2, 5, 4, 3, 2],
                    [320, 512, 2, 1, 8, 4, 3, 2]
                ],
                'decoder_channels': 256
            }
        else:
            raise ValueError(f"Unsupported version: {version}")
        # Build the encoder stages
        self.stages = nn.ModuleList()
        in_channels_list = []  # input-channel list for the decoder
        for i, stage_config in enumerate(config['stages']):
            in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride = stage_config
            stage = SegFormerStage(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio,
                patch_size=patch_size,
                stride=stride
            )
            self.stages.append(stage)
            in_channels_list.append(embed_dim)
        # Build the decoder
        self.decoder = SegFormerDecoder(
            in_channels_list=in_channels_list,
            unified_channels=config['decoder_channels'],
            num_classes=num_classes
        )
    def forward(self, x):
        # Collect each stage's output; each stage consumes the previous stage's feature map
        stage_outputs = []
        for stage in self.stages:
            x = stage(x)
            stage_outputs.append(x)
        # Decode the multi-level features
        seg_mask = self.decoder(stage_outputs)
        # Upsample back to the input resolution
        seg_mask = F.interpolate(seg_mask, scale_factor=4, mode='bilinear', align_corners=False)
        return seg_mask
# Quick test
if __name__ == "__main__":
    # Build the model
    model = SegFormer(num_classes=3, version='b0')
    print(f"Parameter count: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
    # Dummy input
    input_tensor = torch.randn(2, 3, 512, 512)  # [batch, channels, height, width]
    # Forward pass
    output = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")  # expected: [2, 3, 512, 512]
    # Simple check of the output range
    print(f"Output min: {output.min().item():.4f}, max: {output.max().item():.4f}")
    # Optional: render the model graph
    try:
        from torchviz import make_dot
        dot = make_dot(output, params=dict(model.named_parameters()))
        dot.render("segformer_model", format="png")
        print("Model graph saved as segformer_model.png")
    except ImportError:
        print("torchviz not installed, skipping the model graph")
To wrap up, here is DeepSeek's commentary on the model's post-training metrics, LOL.