MAE代码粗略解读

发布于:2024-07-01 ⋅ 阅读:(124) ⋅ 点赞:(0)
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

github地址:mae/models_mae.py at main · facebookresearch/mae (github.com)

结合了视觉Transformer作为骨干网络的Masked Autoencoder。让我来解释一下它的构建过程和逻辑关系:

MaskedAutoencoderViT 的构建过程和逻辑:

  1. 初始化函数 __init__:

    • 在初始化函数中,定义了模型的各种参数和组件。
    • 编码器部分:
      • patch_embed: 使用 PatchEmbed 类将输入图像分割为补丁,并将其嵌入到 embed_dim 维度空间中。
      • cls_token: 类别标记,用于表示整个图像的信息。
      • pos_embed: 位置编码,采用固定的正弦余弦位置编码。
      • blocks: 使用 Block 类的模块列表,构建深度为 depth 的编码器。
      • norm: 归一化层,对编码器的输出进行归一化处理。
    • 解码器部分:
      • decoder_embed: 将编码器输出嵌入到 decoder_embed_dim 维度空间中。
      • mask_token: 用于掩码的标记。
      • decoder_pos_embed: 解码器的位置编码,与编码器相似,使用固定的正弦余弦位置编码。
      • decoder_blocks: 使用 Block 类的模块列表,构建深度为 decoder_depth 的解码器。
      • decoder_norm: 解码器输出的归一化层。
      • decoder_pred: 预测器投影层,将解码器输出映射到预测的图像补丁空间。
    • 其他参数包括 norm_pix_loss,用于指示是否对像素损失进行归一化。
  2. 初始化权重函数 initialize_weights:

    • 初始化模型的权重和层,包括位置编码、嵌入层和归一化层。
  3. 辅助函数 patchifyunpatchify:

    • patchify: 将输入的图像转换为补丁序列。
    • unpatchify: 将补丁序列转换回图像。
  4. 随机掩码函数 random_masking:

    • 对输入进行随机掩码处理,用于在训练期间模拟部分信息丢失的情况。
  5. 编码器前向传播函数 forward_encoder:

    • 接收图像输入,并在编码器中执行前向传播。
    • 将输入图像转换为嵌入补丁序列。
    • 执行随机掩码操作。
    • 将类别标记添加到序列开头。
    • 应用多个编码器块来处理嵌入的补丁序列。
  6. 解码器前向传播函数 forward_decoder:

    • 接收编码器的输出,并在解码器中执行前向传播。
    • 将编码器输出映射到解码器的嵌入维度空间。
    • 在序列中添加掩码标记。
    • 应用多个解码器块来处理解码器输入序列。
    • 使用预测器投影层将解码器输出映射回图像补丁空间。
  7. 损失计算函数 forward_loss:

    • 计算模型的损失,包括像素级损失和掩码损失。
  8. 整体前向传播函数 forward:

    • 整合编码器和解码器的前向传播过程。
    • 返回损失值、解码器的预测输出和随机掩码。

构建不同类型的 MAE 模型的函数:

  • mae_vit_base_patch16_dec512d8b: 创建一个基础的 MAE 模型,使用 MaskedAutoencoderViT 类来构建,具有较小的视觉Transformer骨干和解码器参数设置。
  • mae_vit_large_patch16_dec512d8bmae_vit_huge_patch14_dec512d8b: 创建更大的 MAE 模型,具有更大的视觉Transformer骨干,但解码器参数设置相同。

 损失介绍:

像素级损失(Pixel-wise Loss)

像素级损失是指模型预测的图像补丁与原始输入图像补丁之间的差异。在这里,模型的目标是尽可能准确地重建输入图像的各个部分,因此像素级损失衡量了模型在像素级别上的重建质量。

具体计算像素级损失的步骤如下:

  1. 图像补丁转换(Patchify)

    • 输入图像通过 patchify 函数被转换为一个补丁序列,每个补丁表示图像的一个小部分。
  2. 模型预测

    • 模型通过前向传播生成预测的图像补丁序列。
  3. 计算差异

    • 计算预测的图像补丁序列与原始图像补丁序列之间的差异,通常使用均方误差(MSE)来衡量:
    • N 是批量大小,LLL 是补丁序列的长度,predij\text{pred}_{ij}predij​ 和 targetij\text{target}_{ij}targetij​ 分别是预测和目标图像补丁序列中第 iii 个样本和第 jjj 个补丁的像素值。
  4. 均值处理

    • 对每个补丁的损失取均值,以得到每个样本的平均损失。

掩码损失(Mask Loss)

掩码损失是用于处理模型中的掩码机制,即在训练过程中模拟信息缺失或丢失的情况。在这里,模型通过随机掩码一部分输入补丁来学习如何对缺失信息进行预测,这种机制有助于提高模型的鲁棒性和泛化能力。

具体计算掩码损失的步骤如下:

  1. 掩码生成

    • 使用 random_masking 函数,对输入的补丁序列进行随机掩码,生成掩码后的补丁序列和对应的掩码(0 表示保留,1 表示移除)。
  2. 损失计算

    • 计算预测的补丁序列与原始补丁序列的像素级损失。
    • 将损失乘以掩码,以便仅考虑掩码移除的部分(即模型需要预测的部分)的损失。
    • 最终损失通过除以掩码的总和来归一化,以得到平均损失。

 


网站公告

今日签到

点亮在社区的每一天
去签到