如何对遥感图像进行目标检测?
1. 遥感图像目标检测的基本流程
遥感图像目标检测是从卫星、无人机等遥感影像中自动识别和定位感兴趣目标(如建筑、车辆、机场等)的技术,核心流程包括:
- 数据预处理:辐射校正(消除传感器误差)、几何校正(修正地形/投影偏差)、裁剪/下采样(处理高分辨率数据);
- 特征提取:通过卷积神经网络(CNN)提取图像的纹理、形状、光谱等特征;
- 目标定位与分类:利用检测算法(如锚框机制、Transformer等)预测目标的位置(边界框)和类别;
- 后处理:非极大值抑制(NMS)去除冗余框,提升检测精度。
2. 遥感图像目标检测的难点
与自然图像(如手机拍摄的照片)相比,遥感图像的特殊性带来了独特挑战:
- 目标尺度差异极大:同一幅图像中可能同时存在千米级的机场和米级的车辆,尺度跨度可达1000倍以上;
- 目标方向任意:遥感图像为俯视视角,目标(如车辆、船只)可沿任意方向旋转,轴对齐边框(Axis-Aligned BBox)会引入大量背景噪声;
- 小目标密集分布:如停车场的车辆、城区的小型建筑,往往密集排列且像素占比低(可能仅10×10像素);
- 背景复杂且干扰强:地物(如道路、植被)与目标可能具有相似光谱/纹理特征(如车辆与路面颜色接近);
- 数据标注成本高:遥感图像分辨率高(单幅可达GB级),且专业标注需要领域知识(如区分“飞机”和“直升机”)。
3. 解决方案
针对上述难点,主流技术方案包括:
难点 | 解决方案 |
---|---|
尺度差异大 | 多尺度特征融合(如FPN)、动态锚框生成(根据图像内容自适应调整锚框尺度) |
目标方向任意 | 旋转边框(Rotated BBox)回归(如R2CNN、RRPN)、角度感知的损失函数 |
小目标密集 | 高分辨率特征保留(如CSPNet)、超分辨率重建(提升小目标细节)、密集检测头 |
背景复杂 | 注意力机制(如CBAM)抑制背景噪声、多模态融合(结合光谱/雷达数据) |
标注成本高 | 半监督学习(利用少量标注数据训练)、迁移学习(从自然图像模型迁移权重) |
PyTorch实现遥感图像目标检测(简化版)
以下实现一个支持旋转边框的简化检测模型,基于ResNet50+FPN提取特征,使用旋转锚框预测目标的位置(x, y, w, h, θ)和类别。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import numpy as np
import cv2
import os
from PIL import Image
# -------------------------- 1. 数据集定义 --------------------------
class RemoteSensingDataset(Dataset):
def __init__(self, img_dir, ann_dir, img_size=(512, 512)):
"""
遥感数据集初始化
:param img_dir: 图像文件夹路径
:param ann_dir: 标注文件路径(每个图像对应一个txt,每行格式:x_center y_center w h angle class)
:param img_size: 图像resize尺寸
"""
self.img_dir = img_dir
self.ann_dir = ann_dir
self.img_size = img_size
self.img_names = [f for f in os.listdir(img_dir) if f.endswith(('png', 'jpg'))]
self.transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet均值
])
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_name = self.img_names[idx]
img_path = os.path.join(self.img_dir, img_name)
ann_path = os.path.join(self.ann_dir, img_name.replace('.png', '.txt').replace('.jpg', '.txt'))
# 读取图像
img = Image.open(img_path).convert('RGB')
img = self.transform(img)
# 读取标注(旋转框:x_center, y_center, w, h, angle(弧度), class)
boxes = []
labels = []
if os.path.exists(ann_path):
with open(ann_path, 'r') as f:
for line in f.readlines():
xc, yc, w, h, angle, cls = map(float, line.strip().split())
# 归一化坐标转绝对坐标
xc *= self.img_size[0]
yc *= self.img_size[1]
w *= self.img_size[0]
h *= self.img_size[1]
boxes.append([xc, yc, w, h, angle])
labels.append(cls)
boxes = torch.tensor(boxes, dtype=torch.float32) # (N, 5)
labels = torch.tensor(labels, dtype=torch.long) # (N,)
return img, boxes, labels
# -------------------------- 2. 模型结构 --------------------------
class FPN(nn.Module):
"""特征金字塔网络(FPN):融合多尺度特征"""
def __init__(self, in_channels_list, out_channels):
super(FPN, self).__init__()
self.lateral_convs = nn.ModuleList() # 横向卷积(降维到out_channels)
self.fpn_convs = nn.ModuleList() # 输出卷积(消除 aliasing effect)
for in_channels in in_channels_list:
self.lateral_convs.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
self.fpn_convs.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
def forward(self, x):
"""
:param x: 从backbone输出的多尺度特征 [C1, C2, C3, C4](分辨率从高到低)
:return: 融合后的特征 [P1, P2, P3, P4](同输入尺度)
"""
# 横向连接 + 上采样
laterals = [lateral_conv(xi) for lateral_conv, xi in zip(self.lateral_convs, x)]
# 从最高层开始融合
outs = [laterals[-1]]
for i in range(len(laterals)-2, -1, -1):
# 上采样高层特征并与当前层融合
upsample = F.interpolate(outs[-1], size=laterals[i].shape[2:], mode='bilinear', align_corners=True)
outs.append(laterals[i] + upsample)
# 反转顺序(从低层到高层)
outs = outs[::-1]
# 输出卷积
outs = [fpn_conv(out) for fpn_conv, out in zip(self.fpn_convs, outs)]
return outs
class RotatedDetectionHead(nn.Module):
"""旋转框检测头:预测类别和旋转框参数(x, y, w, h, θ)"""
def __init__(self, in_channels, num_classes, num_anchors=9):
super(RotatedDetectionHead, self).__init__()
# 分类头(每个锚框对应num_classes个类别)
self.cls_head = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
)
# 回归头(每个锚框对应5个参数:x, y, w, h, θ)
self.reg_head = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(in_channels, num_anchors * 5, kernel_size=3, padding=1)
)
self.num_classes = num_classes
self.num_anchors = num_anchors
def forward(self, x):
"""
:param x: FPN输出的多尺度特征 [P1, P2, P3, P4]
:return: 分类预测和回归预测(按特征尺度拼接)
"""
cls_preds = []
reg_preds = []
for feat in x:
cls = self.cls_head(feat) # (B, num_anchors*num_classes, H, W)
reg = self.reg_head(feat) # (B, num_anchors*5, H, W)
# 维度调整:(B, H*W*num_anchors, num_classes) 和 (B, H*W*num_anchors, 5)
cls = cls.permute(0, 2, 3, 1).contiguous().view(cls.shape[0], -1, self.num_classes)
reg = reg.permute(0, 2, 3, 1).contiguous().view(reg.shape[0], -1, 5)
cls_preds.append(cls)
reg_preds.append(reg)
return torch.cat(cls_preds, dim=1), torch.cat(reg_preds, dim=1)
class RemoteSensingDetector(nn.Module):
def __init__(self, num_classes):
super(RemoteSensingDetector, self).__init__()
# Backbone:ResNet50(取前4个stage的输出作为FPN输入)
self.backbone = models.resnet50(pretrained=True)
self.backbone_features = nn.ModuleList([
self.backbone.conv1, self.backbone.bn1, self.backbone.relu, # C1 (1/2)
self.backbone.maxpool, self.backbone.layer1, # C2 (1/4)
self.backbone.layer2, # C3 (1/8)
self.backbone.layer3 # C4 (1/16)
])
# FPN输入通道(ResNet50各stage输出通道)
self.fpn = FPN(in_channels_list=[256, 512, 1024, 2048], out_channels=256)
# 检测头
self.detection_head = RotatedDetectionHead(in_channels=256, num_classes=num_classes)
def forward(self, x):
# Backbone特征提取
feats = []
for layer in self.backbone_features:
x = layer(x)
if isinstance(layer, nn.Sequential): # 取layer1~layer3的输出
feats.append(x)
# FPN融合
fpn_feats = self.fpn(feats)
# 检测头预测
cls_pred, reg_pred = self.detection_head(fpn_feats)
return cls_pred, reg_pred
# -------------------------- 3. 损失函数 --------------------------
class RotatedLoss(nn.Module):
def __init__(self, cls_weight=1.0, reg_weight=5.0):
super(RotatedLoss, self).__init__()
self.cls_weight = cls_weight
self.reg_weight = reg_weight
def forward(self, cls_pred, reg_pred, labels, boxes, anchors):
"""
:param cls_pred: 分类预测 (B, N_anchors, num_classes)
:param reg_pred: 回归预测 (B, N_anchors, 5)
:param labels: 真实类别 (B, N_boxes)
:param boxes: 真实旋转框 (B, N_boxes, 5)
:param anchors: 锚框 (N_anchors, 5)
:return: 总损失
"""
# 简化版:假设已通过IOU匹配锚框与真实框,这里直接计算正样本损失
# 实际中需要先进行锚框匹配(如MaxIOU匹配)
pos_mask = ... # 正样本掩码(简化,实际需实现)
num_pos = pos_mask.sum()
# 分类损失(仅正样本)
cls_loss = F.cross_entropy(
cls_pred[pos_mask],
labels.repeat_interleave(num_pos//labels.shape[0]) # 简化,实际需对应标签
)
# 回归损失:Smooth L1(坐标+宽高) + 角度周期性损失
reg_target = self.anchor2target(boxes, anchors[pos_mask]) # 计算锚框到真实框的偏移
reg_loss = F.smooth_l1_loss(reg_pred[pos_mask, :4], reg_target[:, :4]) # 坐标+宽高损失
# 角度损失(考虑周期性:angle ∈ [-π/2, π/2],使用sin/cos转换)
angle_pred = reg_pred[pos_mask, 4]
angle_target = reg_target[:, 4]
angle_loss = 1 - torch.mean(
torch.cos(angle_pred - angle_target) # 余弦损失(角度差越小,损失越小)
)
total_loss = self.cls_weight * cls_loss + self.reg_weight * (reg_loss + angle_loss)
return total_loss
def anchor2target(self, boxes, anchors):
"""将真实框转换为相对于锚框的偏移量(简化版)"""
# 实际中需根据锚框计算dx, dy, dw, dh, dθ
return boxes - anchors # 简化,实际需更复杂的转换
# -------------------------- 4. 训练与推理示例 --------------------------
if __name__ == "__main__":
# 配置
num_classes = 5 # 假设5类目标(如建筑、车辆、机场等)
img_dir = "path/to/remote_sensing/images" # 图像路径
ann_dir = "path/to/remote_sensing/annotations" # 标注路径
batch_size = 2
epochs = 10
lr = 1e-4
# 数据集与加载器
dataset = RemoteSensingDataset(img_dir, ann_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# 模型、损失函数、优化器
model = RemoteSensingDetector(num_classes=num_classes)
criterion = RotatedLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 简化训练循环
model.train()
for epoch in range(epochs):
total_loss = 0.0
for imgs, boxes, labels in dataloader:
optimizer.zero_grad()
cls_pred, reg_pred = model(imgs)
# 生成锚框(简化版:假设已实现锚框生成函数)
anchors = torch.randn(1000, 5) # 示例锚框,实际需根据特征图生成
loss = criterion(cls_pred, reg_pred, labels, boxes, anchors)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")
# 推理示例
model.eval()
with torch.no_grad():
img, _, _ = dataset[0]
img = img.unsqueeze(0) # 加batch维度
cls_pred, reg_pred = model(img)
print("预测类别概率:", cls_pred.softmax(dim=-1)[0, :5]) # 前5个锚框的类别概率
print("预测旋转框参数:", reg_pred[0, :5]) # 前5个锚框的旋转框参数
核心实现步骤
- 数据集定义:处理遥感图像和旋转框标注(格式:[x_center, y_center, width, height, angle, class]);
- 模型结构:Backbone(ResNet50)+ Neck(FPN)+ Head(分类头+旋转框回归头);
- 损失函数:分类损失(交叉熵)+ 旋转框回归损失(Smooth L1 + 角度周期性损失);
- 训练与推理:简化的训练循环和推理逻辑。
代码说明
- 数据集:假设标注文件为txt格式,每行包含旋转框参数(中心坐标、宽高、角度)和类别,通过
RemoteSensingDataset
类加载并预处理。 - 模型:
- Backbone使用ResNet50提取多尺度特征;
- FPN融合不同分辨率特征,缓解尺度差异问题;
- 检测头预测目标类别和旋转框参数(支持任意方向)。
- 损失函数:分类损失用交叉熵,回归损失结合Smooth L1(坐标/宽高)和余弦损失(角度周期性)。
进一步优化方向
- 实现完整的锚框匹配机制(如MaxIOU)和非极大值抑制(NMS) 处理旋转框;
- 加入注意力机制(如SE模块)增强目标特征;
- 使用数据增强(如随机旋转、缩放、噪声添加)提升模型鲁棒性;
- 迁移预训练权重(如从COCO数据集迁移)加速收敛。