Boost Your mAP! Adding the RT-DETR Model to YOLOv8 to Improve Real-Time Detection Accuracy


Table of Contents


Principle

Key Features

Code Implementation

YAML File Implementation

Full Code Download (Ready to Run)


⭐ Welcome to subscribe to my column and learn together ⭐

🚀🚀🚀 Subscribe to the column for timely updates so you never miss a post 🚀🚀🚀

http://t.csdnimg.cn/sVHxv

💡 Network modifications, paper reproductions, optimization and innovation 💡

Principle

Baidu's Real-Time Detection Transformer (RT-DETR) is a cutting-edge, end-to-end object detector that delivers real-time performance with high accuracy. It leverages the power of Vision Transformers (ViT) to process multi-scale features efficiently by decoupling intra-scale interaction from cross-scale fusion. RT-DETR is also highly flexible: inference speed can be adjusted by using a different number of decoder layers, without retraining. The model excels on accelerated backends such as CUDA with TensorRT, outperforming many other real-time object detectors.

The features from the last three stages of the backbone are used as the encoder input. The efficient hybrid encoder transforms these multi-scale features into a sequence of image features through intra-scale feature interaction (AIFI) and the cross-scale feature-fusion module (CCFM). IoU-aware query selection then picks a fixed number of image features to serve as the initial object queries for the decoder. Finally, the decoder, equipped with auxiliary prediction heads, iteratively refines the object queries to produce boxes and confidence scores.

[Figure: Overview of Baidu's RT-DETR. The architecture diagram shows the last three backbone stages {S3, S4, S5} feeding the efficient hybrid encoder (AIFI + CCFM), IoU-aware query selection producing the initial object queries, and the decoder with auxiliary prediction heads generating boxes and confidence scores.]
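Before diving into the YOLOv8-style integration below, note that Ultralytics also exposes RT-DETR directly through its high-level Python API. The snippet below is a minimal usage sketch; the checkpoint name rtdetr-l.pt and the image path are illustrative and assume the ultralytics package is installed.

from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")   # load a pretrained RT-DETR-L checkpoint (Ultralytics typically downloads known weights on first use)
model.info()                    # print a layer/parameter summary
results = model("bus.jpg")      # run inference on a local image (path is illustrative)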

Key Features
  • Efficient hybrid encoder: Baidu's RT-DETR uses an efficient hybrid encoder that processes multi-scale features by decoupling intra-scale interaction from cross-scale fusion. This Vision-Transformer-based design reduces computational cost and enables real-time object detection.
  • IoU-aware query selection: Baidu's RT-DETR improves the initialization of object queries with IoU-aware query selection, letting the model focus on the most relevant objects in the scene and improving detection accuracy.
  • Flexible inference speed: Baidu's RT-DETR supports adjusting the inference speed by using a different number of decoder layers, with no retraining required (see the sketch after this list). This adaptability makes it practical for a wide range of real-time detection scenarios.
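To make the last point concrete, here is a minimal, self-contained sketch (not the Ultralytics implementation) of the idea: each decoder layer has its own auxiliary prediction head, so at inference time predictions can be read from any intermediate layer and the remaining layers skipped, trading a little accuracy for speed without retraining. All module and variable names here are illustrative.

import torch
import torch.nn as nn

hd, nh, nc, nq, ndl = 256, 8, 80, 300, 6  # hidden dim, attention heads, classes, queries, decoder layers

# One box head and one class head per decoder layer (the "auxiliary prediction heads").
decoder_layers = nn.ModuleList(nn.TransformerDecoderLayer(hd, nh, batch_first=True) for _ in range(ndl))
bbox_heads = nn.ModuleList(nn.Linear(hd, 4) for _ in range(ndl))
score_heads = nn.ModuleList(nn.Linear(hd, nc) for _ in range(ndl))

def decode(queries, memory, eval_idx=-1):
    """Run decoder layers only up to eval_idx and return that layer's predictions (early exit)."""
    eval_idx = eval_idx % ndl                 # allow -1 for "use the last layer"
    x = queries
    for i, layer in enumerate(decoder_layers):
        x = layer(x, memory)                  # refine the object queries against the encoder memory
        if i == eval_idx:                     # the remaining layers are skipped -> faster inference
            return bbox_heads[i](x).sigmoid(), score_heads[i](x)

queries = torch.randn(1, nq, hd)              # initial object queries
memory = torch.randn(1, 1000, hd)             # flattened multi-scale image features from the encoder
boxes, scores = decode(queries, memory, eval_idx=3)  # use only the first 4 of 6 decoder layers
print(boxes.shape, scores.shape)              # torch.Size([1, 300, 4]) torch.Size([1, 300, 80])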
Code Implementation
import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_

# NOTE: import paths follow the Ultralytics package layout this class comes from and may
# differ slightly between versions.
from ultralytics.nn.modules.transformer import (MLP, DeformableTransformerDecoder,
                                                DeformableTransformerDecoderLayer)
from ultralytics.nn.modules.utils import bias_init_with_prob, linear_init_
from ultralytics.utils.tal import TORCH_1_10


class RTDETRDecoder(nn.Module):
    """RT-DETR decoder: selects initial object queries from the encoder features and iteratively
    refines them into bounding boxes and class scores."""

    export = False  # export mode

    def __init__(
            self,
            nc=80,
            ch=(512, 1024, 2048),
            hd=256,  # hidden dim
            nq=300,  # num queries
            ndp=4,  # num decoder points
            nh=8,  # num head
            ndl=6,  # num decoder layers
            d_ffn=1024,  # dim of feedforward
            dropout=0.,
            act=nn.ReLU(),
            eval_idx=-1,
            # Training args
            nd=100,  # num denoising
            label_noise_ratio=0.5,
            box_noise_scale=1.0,
            learnt_init_query=False):
        
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num level
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # Backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
        # NOTE: simplified version but it's not consistent with .pt weights.
        # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

        # Denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # Decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # Encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # Decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        self._reset_parameters()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = \
            get_cdn_group(batch,
                          self.nc,
                          self.num_queries,
                          self.denoising_class_embed.weight,
                          self.num_denoising,
                          self.label_noise_ratio,
                          self.box_noise_scale,
                          self.training)

        embed, refer_bbox, enc_bboxes, enc_scores = \
            self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(embed,
                                              refer_bbox,
                                              feats,
                                              shapes,
                                              self.dec_bbox_head,
                                              self.dec_score_head,
                                              self.query_pos_head,
                                              attn_mask=attn_mask)
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
        """Generates anchor boxes (in inverse-sigmoid space) and a validity mask for the given feature-map shapes."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)  # normalize x by w and y by h
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float('inf'))
        return anchors, valid_mask

    def _get_encoder_input(self, x):
        """Projects each backbone feature map to the hidden dim and flattens them into a single (b, sum(h*w), c) sequence."""
        # Get projection features
        x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
        # Get encoder inputs
        feats = []
        shapes = []
        for feat in x:
            h, w = feat.shape[2:]
            # [b, c, h, w] -> [b, h*w, c]
            feats.append(feat.flatten(2).permute(0, 2, 1))
            # [nl, 2]
            shapes.append([h, w])

        # [b, h*w, c]
        feats = torch.cat(feats, 1)
        return feats, shapes

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        """Generates and prepares the input required for the decoder from the provided features and shapes."""
        bs = len(feats)
        # Prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

        # Query selection
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    # TODO
    def _reset_parameters(self):
        """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
        # Class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
        # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
        # linear_init_(self.enc_score_head)
        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            # linear_init_(cls_)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.)
            constant_(reg_.layers[-1].bias, 0.)

        linear_init_(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)

The AIFI (Attention-based Intra-scale Feature Interaction) module in RT-DETR is a key component; together with the CNN-based Cross-scale Feature Fusion Module (CCFM) it forms the encoder of the model. The main ideas behind AIFI are as follows:

  1. Attention-based feature processing: AIFI uses self-attention to process the high-level features of the image. Self-attention lets the model take other relevant parts of the input into account while processing any particular part, which is especially well suited to high-level image features rich in semantic information.
  2. Selective feature interaction: AIFI performs intra-scale interaction only at the S5 level (the highest-level feature map), based on the observation that high-level features carry richer semantic concepts and capture the relationships between conceptual entities in the image more effectively. The same interaction is deliberately not applied to the low-level features, because they lack the necessary semantic depth and interacting on them risks redundant and confusing computation.

In short, the main idea of AIFI is to focus self-attention on the high-level image features, improving object detection and recognition performance while cutting unnecessary computation.

The main roles and characteristics of the AIFI module are as follows (a minimal code sketch follows the list):

1. Reduced computational redundancy: AIFI further cuts the computational redundancy of encoder variant D from the RT-DETR paper by performing intra-scale interaction only at the S5 level.

2. Self-attention over high-level features: AIFI applies self-attention to the semantically rich high-level features, capturing the relationships between conceptual entities in the image. This helps the subsequent modules detect and recognize objects more effectively.

3. No intra-scale interaction on low-level features: because low-level features lack semantic concepts, and mixing them into the interaction with high-level features risks duplication and confusion, AIFI does not perform intra-scale interaction on them.

4. Focus on the S5 level: to validate the points above, AIFI performs intra-scale interaction only at the S5 level, i.e. the module is dedicated to processing high-level features.
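The sketch below captures the AIFI idea in a self-contained form (illustrative, not the Ultralytics implementation): flatten the high-level S5 feature map into a token sequence, add a 2D sine-cosine positional embedding, run one transformer encoder layer (self-attention) over it, and fold the result back into a feature map. The class and function names are made up for this example.

import torch
import torch.nn as nn

def sincos_pos_embed_2d(h, w, dim, temperature=10000.0):
    """2D sine-cosine positional embedding, shape (1, h*w, dim), row-major over (h, w)."""
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing="ij",
    )
    pos_dim = dim // 4  # dim must be divisible by 4
    omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_x = grid_x.flatten()[:, None] * omega[None, :]
    out_y = grid_y.flatten()[:, None] * omega[None, :]
    return torch.cat([out_x.sin(), out_x.cos(), out_y.sin(), out_y.cos()], dim=1)[None]

class AIFISketch(nn.Module):
    """Intra-scale feature interaction on a single high-level (S5) feature map via self-attention."""

    def __init__(self, c=256, nhead=8, dim_ffn=1024):
        super().__init__()
        self.encoder = nn.TransformerEncoderLayer(c, nhead, dim_ffn, batch_first=True)

    def forward(self, x):                        # x: (B, C, H, W), the S5 feature map only
        b, c, h, w = x.shape
        pos = sincos_pos_embed_2d(h, w, c).to(x.dtype)
        seq = x.flatten(2).permute(0, 2, 1)      # (B, H*W, C): each spatial position becomes a token
        seq = self.encoder(seq + pos)            # self-attention across all spatial positions
        return seq.permute(0, 2, 1).reshape(b, c, h, w)

s5 = torch.randn(1, 256, 20, 20)                 # only the highest-level feature map is processed
print(AIFISketch()(s5).shape)                    # torch.Size([1, 256, 20, 20])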

YAML File Implementation

nc: 80  # number of classes
scales: # model compound scaling constants
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]]  # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]]  # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]]  # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]]  # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]]  # 4-P3/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]]  # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]  # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]]  # 8-P4/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]]  # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]]  # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, RepC3, [256]]  # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]  # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]]  # cat Y4
  - [-1, 3, RepC3, [256]]  # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]]  # cat Y5
  - [-1, 3, RepC3, [256]]  # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]]  # Detect(P3, P4, P5)

Extensive experiments show that RT-DETR achieves state-of-the-art speed and accuracy compared with other real-time detectors and end-to-end detectors of similar size.

Full Code Download (Ready to Run)

Launch command

yolo train model=rtdetr.yaml data=coco8.yaml epochs=100 imgsz=640
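If you prefer the Python API, the equivalent call is sketched below; the file names match the CLI command above, and rtdetr.yaml is assumed to be the YAML from this post saved locally.

from ultralytics import RTDETR

model = RTDETR("rtdetr.yaml")                           # build the model from the YAML above
model.train(data="coco8.yaml", epochs=100, imgsz=640)   # same settings as the yolo CLI command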

code

链接: https://pan.baidu.com/s/1g7WShLHPTT_sb4Hm_xLGwg?pwd=sed3 提取码: sed3 
