DAB-DETR: Dynamic Anchor Boxes Are Better Queries for DETR
摘要
在本文中,我们提出了一种新的查询形式,使用动态锚框作为DETR(DEtection TRansformer)的查询,并提供了对DETR中查询角色的更深入理解。这种新形式直接使用框坐标作为Transformer解码器中的查询,并逐层动态更新它们。使用框坐标不仅有助于利用显式的定位先验来提高查询与特征的相似性,消除DETR中训练收敛缓慢的问题,还允许我们使用框的宽度和高度信息来调节定位注意力图。这种设计使得DETR中的查询可以解释为逐层以级联方式执行软ROI池化。因此,它在相同设置下在MS-COCO基准测试中取得了DETR类检测模型的最佳性能,例如,使用ResNet50-DC5作为骨干网在50个epoch内训练,达到45.7%的AP。我们还进行了广泛的实验以确认我们的分析并验证我们方法的有效性。代码可在此处获取。
介绍
物体检测是计算机视觉中的一项基本任务,应用广泛。大多数经典的检测器基于卷积架构,过去十年取得了显著进展。最近,Carion等人提出了一种基于Transformer的端到端检测器DETR,它消除了对手工设计组件(如锚点)的需求,并与现代基于锚点的检测器(如Faster RCNN)相比表现出色。
与基于锚点的检测器相比,DETR将目标检测建模为集合预测问题,并使用100个可学习的查询从图像中探测和提取特征,从而无需使用非极大值抑制进行预测。然而,由于其查询设计和使用效率低下,DETR在训练收敛速度上存在显著问题,通常需要500个epoch才能取得良好表现。为了解决这个问题,许多后续工作试图改进DETR查询的设计,以加快训练收敛速度并提高性能。
尽管取得了所有这些进展,但DETR中查询的作用仍未得到充分理解和利用。大多数之前的尝试使DETR中的每个查询更明确地与一个特定的空间位置相关,而不是多个位置。然而,技术解决方案各不相同。例如,Conditional DETR通过根据内容特征适应查询来学习条件空间查询,以便更好地与图像特征匹配。Efficient DETR引入了密集预测模块来选择前K个目标查询,而Anchor DETR将查询设计为2D锚点,两者都使每个查询与特定空间位置相关。类似地,Deformable DETR直接将2D参考点视为查询,并在每个参考点执行可变形的跨注意力操作。然而,所有这些工作仅利用2D位置作为锚点,而没有考虑对象的尺度。
受这些研究的启发,我们仔细研究了Transformer解码器中的跨注意力模块,并提出使用锚框(即4D框坐标(x, y, w, h))作为DETR中的查询,并逐层更新它们。这种新的查询形式通过同时考虑每个锚框的位置和大小,为跨注意力模块引入了更好的空间先验,这也带来了更简单的实现和对DETR中查询角色的更深理解。
这种形式的关键见解是,DETR中的每个查询由两个部分组成:内容部分(解码器自注意输出)和位置部分(例如,DETR中的可学习查询)。跨注意力权重是通过将查询与一组由内容部分(编码图像特征)和位置部分(位置嵌入)组成的键进行比较计算的。因此,Transformer解码器中的查询可以解释为基于查询与特征相似性度量从特征图中提取特征,考虑了内容和位置信息。虽然内容相似性用于提取语义相关的特征,但位置相似性提供了在查询位置附近提取特征的位置约束。这种注意力计算机制激励我们将查询形式化为锚框,如图1©所示,允许我们使用锚框的中心位置(x, y)来提取中心周围的特征,并使用锚框大小(w, h)来调节跨注意力图,使其适应锚框大小。此外,由于使用了坐标作为查询,锚框可以逐层动态更新。这样,DETR中的查询可以逐层级联地执行软ROI池化。
我们通过使用锚框大小调节跨注意力,为特征提取提供了更好的定位先验。因为跨注意力可以从整个特征图中提取特征,所以为每个查询提供适当的位置先验是至关重要的,以便跨注意力模块能够专注于对应目标对象的局部区域。这也有助于加速DETR的训练收敛。大多数先前的工作通过将每个查询与特定位置相关联来改进DETR,但它们假设固定大小的各向同性高斯位置先验,这对于不同尺度的对象是不合适的。利用查询锚框中的大小信息(w, h),我们可以将高斯位置先验调节为椭圆形。具体来说,我们分别为x部分和y部分除以跨注意力权重(在softmax之前)的宽度和高度,从而帮助高斯先验更好地与不同尺度的对象匹配。为了进一步改进位置先验,我们还引入了一个温度参数来调节位置注意力的平坦度,这是所有先前工作中都被忽视的。
综上所述,我们提出的DAB-DETR(动态锚框DETR)通过直接学习锚点作为查询,提出了一种新的查询形式。这种形式提供了对查询角色的更深理解,允许我们使用锚框大小来调节Transformer解码器中的位置跨注意力图,并逐层执行动态锚框更新。我们的结果表明,DAB-DETR在相同设置下在COCO目标检测基准测试中取得了DETR类架构的最佳性能。使用单一ResNet-50模型作为骨干网训练50个epoch时,提出的方法可以达到45.7%的AP。我们还进行了广泛的实验以确认我们的分析并验证我们方法的有效性。
2 相关工作
大多数经典的检测器是基于锚点的,使用锚框(Ren等,2017;Girshick,2015;Sun等,2021)或锚点(Tian等,2019;Zhou等,2019)。相比之下,DETR(Carion等,2020)是一个完全无锚的检测器,使用一组可学习的向量作为查询。许多后续工作试图从不同角度解决DETR的慢收敛问题。Sun等(2020)指出,DETR训练缓慢的原因在于解码器中的跨注意力,因此提出了一个仅使用编码器的模型。Gao等(2021)引入了高斯先验来调节跨注意力。尽管它们提高了性能,但并没有给出DETR慢训练和查询角色的合理解释。
另一个改进DETR的方向(与我们工作更相关)是深入理解DETR中查询的角色。由于DETR中的可学习查询用于提供特征提取的定位约束,大多数相关工作尝试使DETR中的每个查询更明确地与一个特定的空间位置相关,而不是多个位置模式。例如,Deformable DETR(Zhu等,2021)直接将2D参考点视为查询,并为每个参考点预测可变形采样点以执行可变形的跨注意力操作。Conditional DETR(Meng等,2021)解耦了注意力的形成,并基于参考坐标生成位置查询。Efficient DETR(Yao等,2021)引入了一个密集预测模块,以选择前K个位置作为目标查询。虽然这些工作将查询与位置信息相关联,但它们没有一个明确的形式来使用锚点。
不同于之前假设的可学习查询向量包含框坐标信息的假设,我们的方法基于一个新的观点,即查询中包含的所有信息都是框坐标。即,锚框是更好的DETR查询。一项同时进行的工作Anchor DETR(Wang等,2021)也建议直接学习锚点,但它忽略了与其他先前工作一样的锚点宽度和高度信息。除了DETR之外,Sun等(2021)提出了一个通过直接学习框的稀疏检测器,这与我们的锚点形式相似,但它抛弃了Transformer结构,采用硬ROI对齐进行特征提取。表1总结了相关工作与我们提出的DAB-DETR之间的主要区别。我们从五个维度比较了我们的模型与相关工作:模型是否直接学习锚点,模型是否在中间阶段预测参考坐标,模型是否逐层更新参考锚点,模型是否使用标准的密集跨注意力,模型是否调节注意力以更好地匹配不同尺度的对象。更详细的DETR类模型比较见附录B。我们建议读者阅读该部分以解答关于表格的困惑。
3 为什么位置先验可以加速训练?
自注意力编码器 vs 跨注意力解码器
对DETR的训练收敛速度进行了大量研究,但缺乏统一的理解来解释这些方法为何有效。Sun等(2020)指出,DETR训练缓慢的原因主要在于解码器中的跨注意力模块,但他们只是简单地移除了解码器以加速训练。我们按照他们的分析找出跨注意力模块中哪个子模块影响了性能。将编码器中的自注意力模块与解码器中的跨注意力模块进行比较,我们发现它们输入的主要区别在于查询,如图2所示。由于解码器嵌入初始化为0,它们在第一个跨注意力模块之后投影到与图像特征相同的空间。之后,它们将在解码器层中经历与图像特征在编码器层中相似的过程。因此,问题的根源很可能在于可学习查询。
在跨注意力模块中,有两个可能的原因导致模型训练收敛缓慢:1)由于优化挑战,难以学习查询;2)可学习查询中的位置信息没有以与图像特征中使用的正弦位置编码相同的方式编码。为了验证是否是第一个原因,我们重新使用DETR中训练良好的查询(保持它们固定),仅训练其他模块。图3(a)中的训练曲线显示,固定查询仅在非常早期的epoch中(如前25个epoch)略微提高了收敛速度。因此,查询学习(或优化)可能不是主要问题。
接下来,我们转向第二种可能性,尝试找出可学习查询是否具有某些不良特性。由于可学习查询用于筛选特定区域中的目标,我们在图4(a)中可视化了一些可学习查询与图像位置嵌入之间的位置注意力图。每个查询可以看作是一个位置先验,让解码器专注于一个感兴趣区域。虽然它们充当位置约束,但它们也具有不良特性:多模式和几乎均匀的注意力权重。例如,图4(a)顶部的两个注意力图有两个或更多的集中中心,当图像中存在多个目标时,很难定位目标。图4(a)底部的图聚焦在太大或太小的区域,因此不能在特征提取过程中注入有用的位置信息。我们推测DETR查询的多模式特性可能是其训练缓慢的根本原因,并认为引入显式位置先验来约束查询在局部区域是有利的。为了验证这个假设,我们用动态锚框替换DETR中的查询形式,从而强制每个查询专注于一个特定区域,并将这种模型命名为DETR+DAB。图3(b)中的训练曲线显示,DETR+DAB在检测AP和训练/测试损失方面比DETR表现出更好的性能。请注意,DETR和DETR+DAB之间唯一的区别是查询的形式,没有引入其他技术(如300个查询或焦点损失)。这表明在解决DETR查询的多模式问题后,我们可以实现更快的训练收敛和更高的检测准确性。
一些先前的工作也进行了类似的分析并确认了这一点。例如,SMCA(Gao等,2021)通过在参考点周围应用预定义的高斯图来加速训练。Conditional DETR(Meng等,2021)使用显式位置嵌入作为位置查询进行训练,产生类似高斯核的注意力图,如图4(b)所示。尽管显式位置先验在训练中表现良好,但它们忽略了对象的尺度信息。相比之下,我们提出的DAB-DETR显式考虑了对象的尺度信息,以自适应地调整注意力权重,如图4©所示。
4 DAB-DETR
4.1 概述
按照DETR(Carion等人,2020)的模型结构,我们的模型是一个端到端的目标检测器,包括一个CNN骨干网络、Transformer编码器和解码器,以及用于框和标签的预测头。我们主要改进了解码器部分,如图5所示。
给定一幅图像,我们使用CNN骨干网络提取图像的空间特征,然后通过Transformer编码器细化这些特征。然后,位置查询(锚框)和内容查询(解码器嵌入)一起输入到解码器中,以探测与这些锚点相对应并与内容查询具有相似模式的对象。双重查询逐层更新,逐渐接近目标真实对象。最终解码器层的输出用于通过预测头预测带标签的对象,然后进行二分图匹配来计算损失,如同在DETR中进行的那样。
为了说明我们动态锚框的通用性,我们还设计了一个更强的DAB-Deformable-DETR,具体内容在附录中提供。
4.2 直接学习锚框
正如在第1节中讨论的DETR中查询的角色,我们提出直接学习查询框或称为锚框,并从这些锚点导出位置查询。每个解码器层中有两个注意力模块,包括一个自注意力模块和一个跨注意力模块,分别用于查询更新和特征探测。每个模块都需要查询、键和值来执行基于注意力的值聚合,但这些三元组的输入有所不同。
我们将第 q q q个锚点表示为 A q = ( x q , y q , w q , h q ) A_q = (x_q, y_q, w_q, h_q) Aq=(xq,yq,wq,hq),其中 x q , y q , w q , h q ∈ R x_q, y_q, w_q, h_q ∈ R xq,yq,wq,hq∈R, C q ∈ R D C_q ∈ R^D Cq∈RD和 P q ∈ R D P_q ∈ R^D Pq∈RD分别表示其对应的内容查询和位置查询,其中 D D D是解码器嵌入和位置查询的维度。
给定一个锚点 A q A_q Aq,其位置查询 P q P_q Pq由下式生成:
P q = M L P ( P E ( A q ) ) P_q = MLP(PE(A_q)) Pq=MLP(PE(Aq))
其中 P E PE PE表示从浮点数生成正弦嵌入的位置信息编码,MLP的参数在所有层中共享。由于 A q A_q Aq是四元数,我们在这里重载了 P E PE PE操作符:
P E ( A q ) = P E ( x q , y q , w q , h q ) = C a t ( P E ( x q ) , P E ( y q ) , P E ( w q ) , P E ( h q ) ) PE(A_q) = PE(x_q, y_q, w_q, h_q) = Cat(PE(x_q), PE(y_q), PE(w_q), PE(h_q)) PE(Aq)=PE(xq,yq,wq,hq)=Cat(PE(xq),PE(yq),PE(wq),PE(hq))
符号 C a t Cat Cat表示连接函数。在我们的实现中,位置编码函数 P E PE PE将一个浮点数映射到具有 D / 2 D/2 D/2维的向量: P E : R → R D / 2 PE: R → R^{D/2} PE:R→RD/2。因此,MLP函数将 2 D 2D 2D维向量投影到 D D D维空间: M L P : R 2 D → R D MLP: R^{2D} → R^D MLP:R2D→RD。MLP模块有两个子模块,每个子模块由一个线性层和一个 R e L U ReLU ReLU激活组成,并在第一个线性层中进行特征降维。
在自注意力模块中,查询、键和值都有相同的内容项,但查询和键包含额外的位置项:
S e l f − A t t n : Q q = C q + P q , K q = C q + P q , V q = C q Self-Attn: Q_q = C_q + P_q, K_q = C_q + P_q, V_q = C_q Self−Attn:Qq=Cq+Pq,Kq=Cq+Pq,Vq=Cq
受 C o n d i t i o n a l D E T R Conditional\ DETR Conditional DETR(Meng等人,2021)的启发,我们在跨注意力模块中将位置和内容信息一起连接为查询和键,以便我们可以解耦内容和位置对查询与特征相似性度量的贡献,该度量是通过查询和键之间的点积计算的。为了重新调整位置嵌入,我们还利用了条件空间查询(Meng等人,2021)。更具体地说,我们学习了一个 M L P ( c s q ) : R D → R D MLP(csq): R^D → R^D MLP(csq):RD→RD,以获得一个基于内容信息的尺度向量,并使用它对位置嵌入进行元素级乘法:
C r o s s − A t t n : Q q = C a t ( C q , P E ( x q , y q ) ⋅ M L P ( c s q ) ( C q ) ) , K x , y = C a t ( F x , y , P E ( x , y ) ) , V x , y = F x , y Cross-Attn: Q_q = Cat(C_q, PE(x_q, y_q) \cdot MLP(csq)(C_q)), K_{x,y} = Cat(F_{x,y}, PE(x, y)), V_{x,y} = F_{x,y} Cross−Attn:Qq=Cat(Cq,PE(xq,yq)⋅MLP(csq)(Cq)),Kx,y=Cat(Fx,y,PE(x,y)),Vx,y=Fx,y
其中 F x , y ∈ R D F_{x,y} ∈ R^D Fx,y∈RD是位置 ( x , y ) (x, y) (x,y)处的图像特征, ⋅ \cdot ⋅表示元素级乘法。查询和键中的位置嵌入都基于 2 D 2D 2D坐标生成,使其与之前的工作(Meng等人,2021;Wang等人,2021)中相似位置的比较更加一致。
4.3 锚框更新
将坐标作为查询进行学习使得可以逐层更新它们。相比之下,对于高维嵌入的查询(如 D E T R DETR DETR和 C o n d i t i o n a l D E T R Conditional\ DETR Conditional DETR),很难逐层进行查询细化,因为不清楚如何将更新后的锚框转换回高维查询嵌入。
按照先前的做法(Zhu等人,2021;Wang等人,2021),我们在每层中更新锚点,并通过预测头预测相对位置 ( Δ x , Δ y , Δ w , Δ h ) (Δx, Δy, Δw, Δh) (Δx,Δy,Δw,Δh),如图5所示。请注意,不同层中的所有预测头共享相同的参数。
4.4 宽度和高度调制的高斯核
传统的位置注意力图使用高斯状的先验,如图6左所示。但这种先验对于所有对象来说都是各向同性且固定大小的,忽略了对象的尺度信息。为了改进位置先验,我们建议在注意力图中注入尺度信息。
在原始位置注意力图中,查询到键的相似性计算为两个坐标编码点积之和:
Attn ( ( x , y ) , ( x r e f , y r e f ) ) = P E ( x ) ⋅ P E ( x r e f ) + P E ( y ) ⋅ P E ( y r e f ) D \text{Attn}((x, y), (xref, yref)) = \frac{PE(x) \cdot PE(xref) + PE(y) \cdot PE(yref)}{\sqrt{D}} Attn((x,y),(xref,yref))=DPE(x)⋅PE(xref)+PE(y)⋅PE(yref)
其中 1 D \frac{1}{\sqrt{D}} D1 用于按Vaswani等(2017)建议重新调整值。我们通过将相对锚点的宽度和高度分别从 x x x部分和 y y y部分除以来调节位置注意力图(在softmax之前),以平滑高斯先验,使其更好地与不同尺度的对象匹配:
ModulateAttn ( ( x , y ) , ( x r e f , y r e f ) ) = P E ( x ) ⋅ P E ( x r e f ) w q + P E ( y ) ⋅ P E ( y r e f ) h q \text{ModulateAttn}((x, y), (xref, yref)) = \frac{PE(x) \cdot PE(xref)}{w_q} + \frac{PE(y) \cdot PE(yref)}{h_q} ModulateAttn((x,y),(xref,yref))=wqPE(x)⋅PE(xref)+hqPE(y)⋅PE(yref)
其中 w q w_q wq和 h q h_q hq分别是锚点 A q A_q Aq的宽度和高度, w q , r e f w_q,ref wq,ref和 h q , r e f h_q,ref hq,ref是通过下式计算的参考宽度和高度:
w q , r e f , h q , r e f = σ ( M L P ( C q ) ) w_q,ref, h_q,ref = \sigma(MLP(C_q)) wq,ref,hq,ref=σ(MLP(Cq))
这种调制的位置注意力帮助我们提取不同宽度和高度的对象特征,图6展示了调制后的注意力的可视化结果。
4.5 温度调节
对于位置编码,我们使用正弦函数(Vaswani等,2017),其定义为:
P E ( x ) 2 i = sin ( x T 2 i / D ) , P E ( x ) 2 i + 1 = cos ( x T 2 i / D ) PE(x)_{2i} = \sin\left(\frac{x}{T^{2i/D}}\right), PE(x)_{2i+1} = \cos\left(\frac{x}{T^{2i/D}}\right) PE(x)2i=sin(T2i/Dx),PE(x)2i+1=cos(T2i/Dx)
其中 T T T是手动设计的温度, 2 i 2i 2i和 2 i + 1 2i+1 2i+1分别表示编码向量中的索引。温度 T T T在式(8)中影响位置先验的大小,如图7所示。较大的 T T T会导致更平坦的注意力图,反之亦然。请注意,温度 T T T在(Vaswani等,2017)中被硬编码为 10000 10000 10000,用于自然语言处理,其中 x x x的值是表示每个单词在句子中位置的整数。然而,在 D E T R DETR DETR中, x x x的值是介于 0 0 0和 1 1 1之间的浮点数,表示边界框的坐标。因此,对于视觉任务,急需不同的温度。在这项工作中,我们经验性地选择 T = 20 T=20 T=20用于我们所有的模型。