CVPR 2025 | 调制融合模块MFM:即插即用的特征选择与细节还原利器,即插即用,涨点起飞!

发布于:2025-09-03 ⋅ 阅读:(17) ⋅ 点赞:(0)

1. 基本信息

  • 标题: Depth Information Assisted Collaborative Mutual Promotion Network for Single Image Dehazing

  • 论文来源: https://arxiv.org/pdf/2403.01105

  • 作者与单位: Yafei Zhang, Shen Zhou, Huafeng Li (昆明理工大学信息工程与自动化学院)

2. 核心创新点

  1. 双任务协同框架: 首次将单图像去雾与深度估计视为两个独立但可协同的任务,并构建了一个统一的端到端学习框架,实现二者的相互促进与性能提升。

  2. 差异感知交互机制: 设计了一种基于“差异感知”(Difference Perception)的双任务交互机制,作为连接去雾和深度估计的桥梁,实现了两个任务的联合优化。

  3. 去雾引导深度估计: 利用去雾图像与真实清晰图像之间的差异,引导深度估计网络关注去雾效果不佳的区域,从而提升深度预测的准确性。

  4. 深度辅助去雾优化: 利用去雾结果的深度图与真实深度图的差异,反向指导去雾网络关注这些区域,同时将有雾图像的深度信息作为辅助输入,提升去雾效果。

➔➔➔➔点击查看原文,获取本文及其他精选即插即用模块集合https://mp.weixin.qq.com/s/5sSrssU31MOwpEdVaeadrw

3. 方法详解

整体结构概述

该方法提出了一种名为DCMPNet的深度信息辅助协同互促网络。其核心是一个双任务(图像去雾、深度估计)协同学习框架。该框架由一个去雾网络和一个深度估计网络(DE)组成,二者通过一个核心的差异感知(DP)模块进行交互。数据流上,有雾图像首先输入去雾网络进行处理,其输出的去雾结果和中间特征,与深度估计网络的结果进行“差异感知”计算,生成损失来共同优化两个网络,形成一个闭环的相互增强系统。

步骤分解

  1. 去雾网络 (Dehazing Network):

    • 编码器: 采用U-Net结构,并集成了多个关键模块。首先通过U-Net进行初步特征提取。随后,特征流经多个局部特征嵌入的全局特征提取模块(LEGM),该模块结合了自注意力机制和卷积操作。值得注意的是,第一个LEGM会接收由深度估计网络生成的有雾图像深度图作为辅助输入。编码器各阶段的输出通过多尺度聚合注意力模块(MSAAM)进行融合,以防止浅层特征丢失。

    • 解码器: 主要由带有特征调制植入(FMI)的LEGM构成。FMI内部包含一个调制融合模块(MFM),该模块通过动态调整权重来融合来自编码器不同层级的特征,增强有效信息的表达。最终,解码器输出重建后的清晰图像 u*

  2. 深度估计网络 (Depth Estimation Network, DE):

    • 该网络采用基于扩张残差密集块(DRDB)的U-Net架构,负责从输入的图像(包括去雾结果 u* 和原始有雾图像 ũ*)中估计其深度图。

  3. 差异感知与双任务互促学习 (Difference Perception & Mutual Promotion):

    • 去雾促进深度估计: 计算去雾结果 u* 与真实图像 u 之间的差异 R(u*/u)。这个差异反映了去雾效果不佳的区域。该差异信息通过一个感知器调整深度估计网络的损失函数,使其更加关注这些区域的深度预测准确性。深度估计网络的损失函数如下: 其中 M 代表深度图,A(d,r) 是由差异 R(u*/u) 生成的权重矩阵。

    • 深度估计促进去雾: 计算去雾结果的深度图 M(u*) 与真实深度图 M(u) 之间的差异。这个差异被反馈给去雾网络的损失函数,促使去雾网络优化那些导致深度图估计不准的区域(通常也是去雾效果不佳的区域)。去雾网络的损失函数如下: 其中 A(e,r) 是由深度图差异生成的权重矩阵。同时,有雾图像的深度图 M(ũ*) 也直接作为输入注入去雾网络的编码器,为特征提取提供结构先验。

4. 即插即用模块作用

该论文的核心思想——基于差异感知的双任务协同互促机制——可以被视为一个可泛化的“即插即用”优化策略。

适用场景

  • 图像去雾: 本文的核心应用场景,尤其适用于结构复杂或雾气浓度不均的图像。

  • 其他图像复原任务: 如去雨、去噪、低光照增强等。只要能找到一个与主任务强相关、且其性能会受主任务结果质量影响的辅助任务(如深度估计、语义分割、边缘检测等),就可以借鉴此框架。

  • 下游视觉任务预处理: 可作为自动驾驶、目标检测、视频监控等系统中应对恶劣天气条件的视觉增强前端模块。

主要作用

  • 模拟“相互监督”机制: 深度估计任务充当了去雾任务的“监督员”,通过检查去雾结果的深度一致性来发现问题;反之,去雾任务也为深度估计提供了监督信号,帮助其提升在有雾图像上的鲁棒性。

  • 提升模型性能与鲁棒性: 该机制能有效利用辅助任务提供的信息,帮助主任务网络跳出局部最优解,显著提升去雾图像的结构稳定性和细节保真度。

  • 增强对困难区域的关注: 通过“差异感知”,模型能自动将计算资源和注意力集中在去雾效果不佳的区域,进行针对性优化,从而提升整体恢复质量。

  • 降低对单一模型的依赖: 避免了设计一个极其复杂的单一网络,而是通过两个相对独立的网络协同工作,共同解决一个具有挑战性的逆问题。

总结

通过“差异感知”驱动的双任务协同机制,让去雾与深度估计两大任务互相监督、彼此成就,从而突破单一任务的性能瓶颈。

➔➔➔➔点击查看原文,获取本文及其他精选即插即用模块集合https://mp.weixin.qq.com/s/5sSrssU31MOwpEdVaeadrw

5. 即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_

class MFM(nn.Module):
    def __init__(self, dim, height=2, reduction=8):
        super(MFM, self).__init__()

        self.height = height
        d = max(int(dim/reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, d, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(d, dim*height, 1, bias=False)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, in_feats):
        B, C, H, W = in_feats[0].shape

        in_feats = torch.cat(in_feats, dim=1)
        in_feats = in_feats.view(B, self.height, C, H, W)

        feats_sum = torch.sum(in_feats, dim=1)
        attn = self.mlp(self.avg_pool(feats_sum))
        attn = self.softmax(attn.view(B, self.height, C, 1, 1))

        out = torch.sum(in_feats*attn, dim=1)
        return out
    
if __name__ == "__main__":
    # 设置输入张量大小
    batch_size = 1
    channels = 32# 输入的通道数
    height, width = 256, 256# 假设输入图像尺寸为 256*256

    # 创建输入张量列表,假设有两个特征图
    input_tensor1 = torch.randn(batch_size, channels, height, width) # 输入张量1
    input_tensor2 = torch.randn(batch_size, channels, height, width) # 输入张量2

    # 初始化 MFM 模块
    mfm = MFM(dim=channels, height=2, reduction=8)
    print(mfm)

    # 前向传播测试
    output = mfm([input_tensor1, input_tensor2])

    # 打印输入和输出的形状
    print(f"Input1 shape: {input_tensor1.shape}")
    print(f"Input2 shape: {input_tensor2.shape}")
    print(f"Output shape: {output.shape}")

网站公告

今日签到

点亮在社区的每一天
去签到