【MMDet Note】AnchorGenerator类的理解

发布于:2022-07-25 ⋅ 阅读:(316) ⋅ 点赞:(0)

文章目录


前言

mmdetection/mmdet/core/anchor/anchor_generator.py中AnchorGenerator类的关键代码解读。

一、总概

AnchorGenerator类的主要目的是生成anchor-based Detector所需要的anchor_box。

先通过【gen_base_anchors】方法生成单个anchor点的9种(3种尺寸、3种宽高比)基anchor_box,再调用【grid_priors】方法将这9种基anchor_box在原图的尺寸上进行广播,得到一个list,其中包括所有原图尺寸上的anchor_box位置信息(左上角坐标与右下角坐标)。

二、代码解读

1. AnchorGenerator类

@PRIOR_GENERATORS.register_module()
class AnchorGenerator:
    """Anchor generator for anchor-based detectors.

    The example values quoted in the comments are the ones used throughout
    this walkthrough: strides [8, 16, 32, 64, 128], ratios [0.5, 1.0, 2.0],
    octave_base_scale 4 and scales_per_octave 3.
    """

    def __init__(self,
                 strides,                       # e.g. [8, 16, 32, 64, 128]
                 ratios,                        # e.g. [0.5, 1.0, 2.0]
                 scales=None,
                 base_sizes=None,
                 scale_major=True,
                 octave_base_scale=None,        # e.g. 4
                 scales_per_octave=None,        # e.g. 3
                 centers=None,
                 center_offset=0.):
        # Explicit centers and a non-zero center offset are mutually
        # exclusive, and the offset itself must lie inside [0, 1].
        if center_offset != 0:
            assert centers is None, 'center cannot be set when center_offset' \
                                    f'!=0, {centers} is given.'
        offset_in_range = 0 <= center_offset <= 1
        if not offset_in_range:
            raise ValueError('center_offset should be in range [0, 1], '
                             f'{center_offset} is given.')
        if centers is not None:
            assert len(centers) == len(strides), \
                'The number of strides should be the same as centers, got ' \
                f'{strides} and {centers}'

        # Every stride is normalised with _pair,
        # e.g. [(8, 8), (16, 16), (32, 32), (64, 64), (128, 128)].
        self.strides = [_pair(s) for s in strides]
        # Without explicit base sizes, each level uses the smaller component
        # of its stride, e.g. [8, 16, 32, 64, 128].
        if base_sizes is None:
            self.base_sizes = [min(pair) for pair in self.strides]
        else:
            self.base_sizes = base_sizes
        assert len(self.base_sizes) == len(self.strides), \
            'The number of strides should be the same as base sizes, got ' \
            f'{self.strides} and {self.base_sizes}'

        # `scales` on one side and the pair (`octave_base_scale`,
        # `scales_per_octave`) on the other are alternative ways to specify
        # the anchor scales: exactly one of them has to be provided.
        has_octave = (octave_base_scale is not None
                      and scales_per_octave is not None)
        has_scales = scales is not None
        assert has_octave != has_scales, \
            'scales and octave_base_scale with scales_per_octave cannot' \
            ' be set at the same time'
        if has_scales:
            self.scales = torch.Tensor(scales)
        elif has_octave:
            # scales = octave_base_scale * 2 ** (k / scales_per_octave);
            # with 4 and 3: 4 * [2^0, 2^(1/3), 2^(2/3)], approx.
            # [4, 5.04, 6.35].
            octave_scales = np.array([
                2**(k / scales_per_octave) for k in range(scales_per_octave)
            ])
            self.scales = torch.Tensor(octave_scales * octave_base_scale)
        else:
            raise ValueError('Either scales or octave_base_scale with '
                             'scales_per_octave should be set')

        # Final attribute values (trailing comments show the example run).
        self.octave_base_scale = octave_base_scale        # 4
        self.scales_per_octave = scales_per_octave        # 3
        self.ratios = torch.Tensor(ratios)                # [0.5, 1, 2]
        self.scale_major = scale_major                    # True
        self.centers = centers                            # None
        self.center_offset = center_offset                # 0
        # Must come last: gen_base_anchors() is called on the instance after
        # every attribute above has been assigned.
        self.base_anchors = self.gen_base_anchors()
        # In the example run:
        #   self.scales  is approx. [4, 5.04, 6.35]
        #   self.strides is [(8, 8), (16, 16), (32, 32), (64, 64), (128, 128)]

2. gen_base_anchors方法

def gen_base_anchors(self):
    """Build the base anchors of every feature level.

    The base anchors are the anchor boxes attached to a single anchor
    point, one per (ratio, scale) combination -- nine boxes in the
    running example with three scales and three ratios.

    Returns:
        list: One entry per element of ``self.base_sizes``; entry ``i`` is
            whatever ``gen_single_level_base_anchors`` returns for level
            ``i``.
    """
    # For each level (base_size is 8 / 16 / 32 / 64 / 128 in the example)
    # delegate to gen_single_level_base_anchors. The level's centre is
    # taken from self.centers when centres were configured, else None.
    return [
        self.gen_single_level_base_anchors(
            base_size,
            scales=self.scales,     # approx. [4, 5.04, 6.35] in the example
            ratios=self.ratios,     # [0.5, 1, 2] in the example
            center=None if self.centers is None else self.centers[level])
        for level, base_size in enumerate(self.base_sizes)
    ]
    # Result layout: [level_0_base_anchors, level_1_base_anchors, ...]

3. gen_single_level_base_anchors方法

def gen_single_level_base_anchors(self,
                                  base_size,        # e.g. 8
                                  scales,           # e.g. approx. [4, 5.04, 6.35]
                                  ratios,           # e.g. [0.5, 1, 2]
                                  center=None):
    """Generate the base anchors of a single feature level.

    Args:
        base_size (int | float): Basic anchor size of this level; used as
            both the basic width and the basic height.
        scales (torch.Tensor): 1-D tensor of anchor scales.
        ratios (torch.Tensor): 1-D tensor of height / width ratios.
        center (tuple[float], optional): ``(x, y)`` centre of the base
            anchors. When ``None`` the centre is
            ``self.center_offset * base_size`` on both axes.

    Returns:
        torch.Tensor: Tensor of shape ``(len(ratios) * len(scales), 4)``
            whose rows are ``(xmin, ymin, xmax, ymax)``.

    Example (two anchors, centre at the origin)::

        base_size, scales, ratios = 8, [4, 6], [1]
        w, h, h_ratios, w_ratios = 8, 8, [1], [1]
        ws = 8 * 1 * [4, 6] = [32, 48]
        hs = 8 * 1 * [4, 6] = [32, 48]
        # the four coordinate vectors before stacking
        [[-16, -24], [-16, -24], [16, 24], [16, 24]]
        # after torch.stack(..., dim=-1)
        [[-16., -16., 16., 16.],
         [-24., -24., 24., 24.]]
    """
    w = base_size
    h = base_size
    if center is None:
        x_center = self.center_offset * w   # 0 with the default offset
        y_center = self.center_offset * h   # 0 with the default offset
    else:
        x_center, y_center = center

    # h_ratios / w_ratios == ratios, while h_ratios * w_ratios == 1, so
    # changing the aspect ratio leaves the anchor area unchanged.
    h_ratios = torch.sqrt(ratios)
    w_ratios = 1 / h_ratios

    if self.scale_major:
        # Rows of the (ratio, scale) grid are ratios: after flattening, the
        # scale index varies fastest.
        ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
    else:
        # Rows of the (scale, ratio) grid are scales: after flattening, the
        # ratio index varies fastest.
        ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
        hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)

    # (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax); stacking along
    # the last dimension turns the four coordinate vectors into one row per
    # anchor.
    base_anchors = torch.stack([
        x_center - 0.5 * ws,
        y_center - 0.5 * hs,
        x_center + 0.5 * ws,
        y_center + 0.5 * hs,
    ], dim=-1)

    return base_anchors

4. grid_priors方法

该方法与【2. gen_base_anchors方法】类似,区别是:1、该方法需要在后续流程中手动调用,而【2. gen_base_anchors方法】是在实例化AnchorGenerator类时(__init__中)自动调用的。2、返回的列表内容不同,该方法返回每个特征图上所有anchor_box相对于原图的位置。

def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
    """Generate the grid anchors of every feature level.

    Args:
        featmap_sizes (list): One feature-map size per level; its length
            must equal ``self.num_levels``.
        dtype (torch.dtype): Forwarded to ``single_level_grid_priors``.
        device (str): Forwarded to ``single_level_grid_priors``.

    Returns:
        list: Entry ``i`` is the result of ``single_level_grid_priors`` for
            level ``i``.
    """
    assert self.num_levels == len(featmap_sizes)
    # Result layout: [level_0_anchors, level_1_anchors, ...]
    return [
        self.single_level_grid_priors(
            featmap_sizes[level], level_idx=level, dtype=dtype, device=device)
        for level in range(self.num_levels)
    ]

 5. single_level_grid_priors方法

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=torch.float32,
                                 device='cuda'):
        """Generate the grid anchors of a single feature level.

        The base anchors of level ``level_idx`` are replicated at every
        feature-map position; positions are multiplied by the level's
        stride, so the resulting coordinates are stride-scaled.

        Args:
            featmap_size (tuple[int]): ``(feat_h, feat_w)`` of the level.
            level_idx (int): Index into ``self.base_anchors`` and
                ``self.strides``.
            dtype (torch.dtype): Data type of the generated anchors.
            device (str): Device the anchors are created on.

        Returns:
            torch.Tensor: Anchors reshaped to ``(-1, 4)``.
        """
        # Base anchors of this level (nine boxes in the running example).
        level_anchors = self.base_anchors[level_idx].to(device).to(dtype)
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]

        # Feature-map indices times the stride: with stride 8, index 0 maps
        # to 0 and index 1 maps to 8.
        cols = torch.arange(0, feat_w, device=device).to(dtype)
        rows = torch.arange(0, feat_h, device=device).to(dtype)
        shift_x = cols * stride_w
        shift_y = rows * stride_h

        # NOTE(review): _meshgrid is defined outside this snippet; it is
        # presumably returning flattened x / y grids -- confirm.
        grid_x, grid_y = self._meshgrid(shift_x, shift_y)
        # One (x, y, x, y) offset per position, matching the
        # (xmin, ymin, xmax, ymax) layout of the anchors.
        shifts = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=-1)

        # Broadcast the anchors (1, A, 4) against the shifts (K, 1, 4) to
        # get (K, A, 4), then flatten to (K * A, 4).
        shifted = level_anchors.unsqueeze(0) + shifts.unsqueeze(1)
        return shifted.view(-1, 4)

总结

本文仅代表个人理解,若有不足,欢迎批评指正。

参考:MMDet逐行解读之AnchorGenerator_武乐乐~的博客-CSDN博客


网站公告

今日签到

点亮在社区的每一天
去签到