注意力机制——CBAM原理详解及源码解析

发布于:2024-03-28 ⋅ 阅读:(77) ⋅ 点赞:(0)

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

注意力机制——CBAM原理详解及源码解析

写在前面

hello,大家好,我是小苏👦🏽👦🏽👦🏽

​ 好久没有更新博客啦,今天来更一期注意力机制的文章吧!!!🌱🌱🌱在之前,我已经为大家介绍了一种最常见的注意力机制——SENET,感兴趣的可以点击☞☞☞了解详情。当然了,我也为大家介绍过目前很火的Transformer注意力机制,足足写了1W+字,自认为写的还是比较清晰,感兴趣的也可以点击☞☞☞了解更多信息。

​ 今天我也为大家介绍一种非常常见的注意力机制——CBAM注意力机制。CBAM全称为 Convolutional Block Attention Module,可以翻译成卷积块的注意模型。和之前介绍SENET一样,我会直接对CBAM的结构进行介绍,并附上关键代码,如果你想了解更多细节,可以阅读原论文,论文下载地址如下:CBAM论文📩📩📩

CBAM原理详解

​ 话不多说,直接来看一下CBAM的结构,如下图所示:

image-20230313201848619

​ 我们可以来简单的分析一下上图,首先有一个原始的特征图,即输入特征图Input Feature。接着会将输入特征图送入一个Channel Attention Module[通道注意力模块]和Spatial Attention Module[空间注意力模块],最终会得到最终的特征图Refined Feature【Refined 翻译为精致的,可以认为是得到了一个理论上更好的特征图🍵🍵🍵】。

​ 上文为大家介绍了CBAM的整体的结构,但是你现在肯定还存在诸多疑惑,不知道其具体是怎么工作的。莫慌,现在我们就来学习Channel Attention ModuleSpatial Attention Module这两个模块。

  • Channel Attention Module

论文上给出了这个图表示通道注意力模块【注意:为方便大家理解,我给论文中的图加上了特征图尺寸】,这个图其实很好理解,我们一起来看一下。首先我们会对输入特征图F分别做一个全局最大池化下采样和全局平均池化下采样,F由原来的 H × W × C H \times W \times C H×W×C变成两个 1 × 1 × C 1 \times 1 \times C 1×1×C 的特征图,接着我们会将这两个特征图送入到两个全连接层[MLP]中,最终会输出两个 1 × 1 × C 1 \times 1 \times C 1×1×C 的特征图。【注意:这里两个特征图共用这两个全连接层,这里的操作和SENET类似,不明白的可以去看看SENET。】得到两个 1 × 1 × C 1 \times 1 \times C 1×1×C的特征图后,我们将其相加并经过sigmoid激活函数将其值限制在0-1之前,这就得到了最后的Channel Attention,即上图中的 M c M_c Mc,其尺寸为 1 × 1 × C 1 \times 1 \times C 1×1×C。这个过程用公式表示如下:

image-20230313204812628

到这里,通道注意力就为大家介绍完了,接下来就要介绍空间注意了。但是大家这里需要注意一点,我们在通道注意力中得到了 1 × 1 × C 1 \times 1 \times C 1×1×C尺寸的特征图 M c M_c Mc,我们并不是将 M c M_c Mc作为输入送到空间注意力中,而是先用F和 M c M_c Mc相乘,得到特征图 F ′ \rm{F'} F F ′ \rm{F'} F的尺寸和F一致,都是 H × W × C H \times W \times C H×W×C。【这块要是不明白的大家也去看看SENET🍚🍚🍚】

  • Spatial Attention Module

    image-20230313205632671

    同样,论文中给出了这个图表示空间注意力模块,就让我们一起来探索探索叭。🌲🌲🌲首先我们知道在上一步我们得到了一个 H × W × C H \times W \times C H×W×C大小的特征图 F ′ \rm{F'} F,空间注意力首先同样会分别进行一个全局最大池化下采样和全局平均池化下采样,但是这时候我们是在channel维度上做,全局最大池化下采用会得到上图蓝色的特征图,其尺寸为 H × W × 1 H \times W \times 1 H×W×1,全局平均池化下采样会得到上图橙色的特征图,其尺寸为 H × W × 1 H \times W \times 1 H×W×1。然后我们将橙色和蓝色的特征图在channel维度进行拼接,会到底 H × W × 2 H \times W \times 2 H×W×2大小的特征图。接着会做一次卷积,将刚刚得到的 H × W × 2 H \times W \times 2 H×W×2的特征图变成 H × W × 1 H \times W \times 1 H×W×1的特征图。最后同样接一个sigmoid激活函数将特征图的值限制在0-1之前,即得到最终的 M s M_s Ms,其尺寸为 H × W × 1 H \times W \times 1 H×W×1。这个过程用公式表示如下:

    image-20230313211307154

​ 上式 f 7 × 7 f^{7 \times 7} f7×7表示卷积操作采用 7 × 7 7 \times 7 7×7的卷积核。空间注意力到这里也介绍完啦!!!🥗🥗🥗同样的,这里我们得到的 M s M_s Ms会和 F ′ \rm{F'} F相乘,得到最终的输出结果,其尺寸同样为 H × W × C H \times W \times C H×W×C

CBAM代码详解

直接上代码叭,如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv1(x))
        return out

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=1, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_planes, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_att(x) * x
        print(self.channel_att(x).shape)
        print(f"channel Attention Module:{out.shape}")
        out = self.spatial_att(out) * out
        print(self.spatial_att(out).shape)
        #print(f"Spatial Attention Module:{out.shape}")
        return out


if __name__ == '__main__':
    # Testing
    model = CBAM(3)
    input_tensor = torch.ones((1, 3, 224, 224))
    output_tensor = model(input_tensor)
    print(f'Input shape: {input_tensor.shape})')
    print(f'Output shape: {output_tensor.shape}')


我们可以来看一下结果,如下:

我简单解释下上图,第一行表示理论部分的特征图 M c M_c Mc,第二行表示 F ′ \rm{F'} F,第三行表示特征图 M s M_s Ms,第四行表示输入,第五行表示CBAM的最终输出。具体细节大家还是阅读阅读上述的源码叭,非常容易理解。

小结

​ 本节就为大家介绍到这里啦,如若有不理解的地方欢迎评论区交流讨论。🌿🌿🌿如果还有想要了解的注意力机制,也可以评论区留言喔,我会记在小本本上的,有机会就会出喔。🌾🌾🌾

 
 

如若文章对你有所帮助,那就🛴🛴🛴

         一键三连 (1).gif

本文含有隐藏内容,请 开通VIP 后查看