YOLO12改进-模块-引入Channel Reduction Attention (CRA)模块 降低模型复杂度,提升复杂场景下的目标定位与分类精度

发布于:2025-05-19 ⋅ 阅读:(20) ⋅ 点赞:(0)

          在语义分割任务中,基于 Transformer 的解码器需要捕获全局上下文信息,但传统自注意力机制(如 SRA)因高分辨率特征图导致计算成本高昂。现有方法多通过降低空间分辨率减少计算量,但未充分优化通道维度。为平衡全局上下文提取与计算效率,MetaSeg 提出 Channel Reduction Attention (CRA)模块,通过压缩查询(Query)和键(Key)的通道维度至一维,在保持性能的同时显著降低计算复杂度。

上面是原模型,下面是改进模型

1.  CRA介绍 

        CRA 基于多头自注意力机制,核心创新在于通道维度压缩。传统自注意力中,Query 和 Key 的通道维度通常为多维向量(如Ci​维),计算复杂度为O(N2Ci​)(N为像素数)。CRA 通过线性投影将 Query 和 Key 的通道维度压缩至一维标量(即Q,K∈RHead×Hi​Wi​×1),使查询 - 键操作的计算复杂度降至O(N2),同时通过平均池化处理值(Value)保持信息完整性。实验表明,一维标量表示仍能有效捕捉全局相似性,且计算量较传统方法减少约 50%(如对比 SRA)

       

CRA 模块结构如图所示,主要包含以下步骤:

       通道压缩:对输入特征Fi​,通过线性投影WjQ​,WjK​将 Query 和 Key 的通道维度压缩至 1,得到Qi​,Ki​;Value 通过平均池化和线性投影WjV​处理为低维特征。

       注意力计算:在一维空间中计算 Query 与 Key 的点积相似度,经 Softmax 生成注意力权重,再与 Value 相乘得到输出。

       多头融合:多头注意力结果拼接后通过线性投影WiO​输出最终特征。

 2. YOLOv12与CRA的结合      

将 CRA 模块插入 YOLO12 ,可增强模型对全局上下文的建模能力,具体作用如下:

        全局信息捕获:在目标检测中,CRA 通过一维通道压缩的自注意力机制,有效捕捉跨区域依赖关系(如物体与背景的关联),提升复杂场景下的目标定位与分类精度。

        计算效率优化:相比传统自注意力,CRA 通过通道降维显著降低计算量,在保持实时性的同时避免 YOLO12 因引入全局建模导致的推理速度下降,尤其适合高分辨率输入场景。

        多尺度特征增强:结合 YOLO12 的特征金字塔网络(FPN),CRA 可在不同尺度特征图上补充全局上下文信息,缓解小目标检测中因局部特征不足导致的漏检问题。

3. CRA代码部分

YOLOv8_improve/YOLOV12.md at master · tgf123/YOLOv8_improve · GitHub

YOLO12模型改进方法,快速发论文,总有适合你的改进,还不改进上车_哔哩哔哩_bilibili

 4. 将CRA引入到YOLOv12中

第一: 先新建一个change_model,将下面的核心代码复制到下面这个路径当中,如下图如所示。YOLOv12\ultralytics\change_model。

            ​​​​​​       

第二:在task.py中导入

 ​​​       ​​​​​​​ ​​​​​​​ ​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​       

第三:在task.py中的模型配置部分下面代码

        ​​​​​​​​​​​​​​ ​​​​​​​​​​​​​​​​​​​​​ ​​​​​​​    ​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​  

第四:将模型配置文件复制到YOLOV12.YAMY文件中

 ​​​​​​​ ​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​  

     ​​​​​​​ ​​​​​​​​​​​​​​ ​​​​​​​ ​​​​​​​​​​​​​第五:运行代码


from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld

if __name__=="__main__":

    # 使用自己的YOLOv12.yamy文件搭建模型并加载预训练权重训练模型
    model = YOLO(r"E:\Part_time_job_orders\YOLO_NEW\YOLOv12_all\ultralytics\cfg\models\12\yolo12_CRA.yaml")
        # .load(r'E:\Part_time_job_orders\YOLO_NEW\YOLOv12\yolo12n.pt')  # build from YAML and transfer weights

    results = model.train(data=r'E:\Part_time_job_orders\YOLO\YOLOv12\ultralytics\cfg\datasets\VOC_my.yaml',
                          epochs=300,
                          imgsz=640,
                          batch=64,
                          # cache = False,
                          # single_cls = False,  # 是否是单类别检测
                          # workers = 0,
                         # resume=r'D:/model/yolov8/runs/detect/train/weights/last.pt',
                          amp = True
                          )


网站公告

今日签到

点亮在社区的每一天
去签到