为了让大家更好地理解,我们将从零开始,逐步构建 ViT 的各个核心组件,并最终将它们组合成一个完整的模型。我们会以一个在 CIFAR-10
数据集上应用的实例来贯穿整个讲解过程。
ViT 核心思想
在讲解代码之前,我们先快速回顾一下 ViT 的核心思想,这有助于理解代码每一部分的目的。
图片切块 (Image to Patches): 传统 CNN 逐像素处理图像,而 ViT 模仿 NLP 中处理单词 (Token) 的方式。它将一幅图像 (H*W*C) 切割成一个个小块 (Patch),每个小块大小为 P*P*C。
展平与线性投射 (Patch Flattening & Linear Projection): 将每个小块展平成一个一维向量,然后通过一个全连接层(线性投射)将其映射到一个固定的维度 D,这个向量就成为了 Transformer 的 "Token"。
类别令牌 (Class Token): 模仿 BERT 的 [CLS]
令牌,在所有 Patch Token 的最前面加入一个可学习的 [CLS]
Token。这个 Token 最终将用于图像分类。
位置编码 (Positional Embedding): Transformer 本身不包含位置信息。为了让模型知道每个 Patch 的原始位置,我们需要为每个 Token(包括 [CLS]
Token)添加一个可学习的位置编码。
Transformer 编码器 (Transformer Encoder): 将带有位置编码的 Token 序列输入到标准的 Transformer Encoder 中。Encoder 由多层堆叠而成,每一层都包含一个多头自注意力模块 (Multi-Head Self-Attention) 和一个前馈网络 (Feed-Forward Network)。
分类头 (MLP Head): 将 Transformer Encoder 输出的 [CLS]
Token 对应的向量,送入一个简单的多层感知机(MLP),最终输出分类结果。
实例设定
我们将以 CIFAR-10
数据集为例。
图片尺寸 (image_size): 32*32*3
Patch 尺寸 (patch_size): 4*4 (我们可以选择 8x8 或 16x16,这里用 4x4 举例)
类别数 (num_classes): 10
嵌入维度 (dim): 512 (每个 Patch 展平后映射到的维度)
Transformer Encoder 层数 (depth): 6
多头注意力头数 (heads): 8
MLP 内部维度 (mlp_dim): 2048
根据这些设定,我们可以计算出:
每张图片的 Patch 数量 (num_patches): (32/4)x(32/4)=8x8=64
PyTorch 代码逐行实现
我们将按照 ViT 的思想,一步步构建代码。
1. Patch Embedding (图像切块与线性投射)
这是 ViT 的第一步,我们的目标是将一个 (B, C, H, W)
的图像张量,转换成一个 (B, N, D)
的 Token 序列张量,其中 B
是批量大小,N
是 Patch 数量,D
是嵌入维度。
一个巧妙高效的实现方法是使用二维卷积。
思想: 我们可以设置一个卷积层,其卷积核大小 (kernel_size) 和步长 (stride) 都等于 patch_size
。这样,卷积核每次滑动的区域恰好就是一个不重叠的 Patch。卷积的输出通道数设为我们想要的嵌入维度 dim
。
import torch
from torch import nn
class PatchEmbedding(nn.Module):
"""
将图像分割成块并进行线性嵌入。
参数:
image_size (int): 输入图像的尺寸 (假设为正方形)。
patch_size (int): 每个图像块的尺寸 (假设为正方形)。
in_channels (int): 输入图像的通道数。
dim (int): 线性投射后的嵌入维度。
"""
def __init__(self, image_size, patch_size, in_channels, dim):
super().__init__()
self.patch_size = patch_size
# 检查图像尺寸是否能被 patch 尺寸整除
if not (image_size % patch_size == 0):
raise ValueError("error")
# 计算 patch 的数量
self.num_patches = (image_size // patch_size) ** 2
# 核心:使用 Conv2d 实现 patch 化和线性投射
# kernel_size 和 stride 都设为 patch_size,实现不重叠的块分割
# out_channels 设为嵌入维度 dim
self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
# 输入 x 的形状: (B, C, H, W)
# 例如: (B, 3, 32, 32)
# 经过卷积层,将图像转换为 patch 的特征图
# 输出形状: (B, dim, H/P, W/P)
# 例如: (B, 512, 8, 8)
x = self.projection(x)
# 将特征图展平
# .flatten(2) 将从第2个维度开始展平 (H/P 和 W/P 维度)
# 输出形状: (B, dim, N) 其中 N = (H/P) * (W/P)
# 例如: (B, 512, 64)
x = x.flatten(2)
# 交换维度,以匹配 Transformer 输入格式 (B, N, D)
# 输出形状: (B, N, dim)
# 例如: (B, 64, 512)
x = x.transpose(1, 2)
return x
2. Transformer Encoder Block
Transformer Encoder 由多个相同的块 (Block) 堆叠而成。每个块包含两个主要部分:
多头自注意力 (Multi-Head Self-Attention)
前馈网络 (Feed-Forward Network / MLP)
每个部分都伴随着残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。
class TransformerEncoderBlock(nn.Module):
"""
标准的 Transformer Encoder 块。
参数:
dim (int): 输入的 token 维度。
heads (int): 多头注意力的头数。
mlp_dim (int): MLP 层的隐藏维度。
dropout (float): Dropout 的概率。
"""
def __init__(self, dim, heads, mlp_dim, dropout=0.1):
super().__init__()
# 第一个 LayerNorm
self.norm1 = nn.LayerNorm(dim)
# 多头自注意力模块
# PyTorch 内置的 MultiheadAttention 期望输入形状为 (N, B, D),
# 但我们通常使用 (B, N, D)。设置 batch_first=True 可以解决这个问题。
self.attention = nn.MultiheadAttention(
embed_dim=dim,
num_heads=heads,
dropout=dropout,
batch_first=True
)
# 第二个 LayerNorm
self.norm2 = nn.LayerNorm(dim)
# MLP / 前馈网络
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(), # ViT 论文中使用的激活函数
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
# x 的形状: (B, N, D)
# 1. 多头自注意力部分
# 残差连接: x + Attention(LayerNorm(x))
x_norm = self.norm1(x)
# 注意力模块返回 attn_output 和 attn_weights,我们只需要前者
attn_output, _ = self.attention(x_norm, x_norm, x_norm)
x = x + attn_output
# 2. 前馈网络部分
# 残差连接: x + MLP(LayerNorm(x))
x_norm = self.norm2(x)
mlp_output = self.mlp(x_norm)
x = x + mlp_output
return x
3. 完整的 Vision Transformer 模型
现在,我们将所有组件整合在一起。
class VisionTransformer(nn.Module):
"""
Vision Transformer 模型。
参数:
image_size (int): 输入图像尺寸。
patch_size (int): Patch 尺寸。
in_channels (int): 输入通道数。
num_classes (int): 分类类别数。
dim (int): 嵌入维度。
depth (int): Transformer Encoder 层数。
heads (int): 多头注意力头数。
mlp_dim (int): MLP 隐藏维度。
dropout (float): Dropout 概率。
"""
def __init__(self, image_size, patch_size, in_channels, num_classes,
dim, depth, heads, mlp_dim, dropout=0.1):
super().__init__()
# 1. Patch Embedding
self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
# 计算 patch 数量
num_patches = self.patch_embedding.num_patches
# 2. Class Token
# 这是一个可学习的参数,维度为 (1, 1, D)
# '1' 个 batch,'1' 个 token,'D' 维
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 3. Positional Embedding
# 这也是一个可学习的参数
# 长度为 num_patches + 1 (为了包含 cls_token)
# 维度为 (1, N+1, D)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(dropout)
# 4. Transformer Encoder
# 使用 nn.Sequential 将多个 Encoder Block 堆叠起来
self.transformer_encoder = nn.Sequential(
*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
)
# 5. MLP Head (分类头)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim), # 在送入分类头前先进行一次 LayerNorm
nn.Linear(dim, num_classes)
)
def forward(self, img):
# img 形状: (B, C, H, W)
# 1. 获取 Patch Embedding
# x 形状: (B, N, D)
x = self.patch_embedding(img)
b, n, d = x.shape # b: batch_size, n: num_patches, d: dim
# 2. 添加 Class Token
# 将 cls_token 复制 b 份,拼接到 x 的最前面
# cls_tokens 形状: (B, 1, D)
cls_tokens = self.cls_token.expand(b, -1, -1)
# x 形状变为: (B, N+1, D)
x = torch.cat((cls_tokens, x), dim=1)
# 3. 添加 Positional Embedding
# pos_embedding 形状是 (1, N+1, D),利用广播机制直接相加
x += self.pos_embedding
x = self.dropout(x)
# 4. 通过 Transformer Encoder
# x 形状不变: (B, N+1, D)
x = self.transformer_encoder(x)
# 5. 提取 Class Token 的输出用于分类
# 只取序列的第一个 token (cls_token) 的输出
# x 形状: (B, D)
cls_token_output = x[:, 0]
# 6. 通过 MLP Head 得到最终的分类 logits
# output 形状: (B, num_classes)
output = self.mlp_head(cls_token_output)
return output
完整模型与实例
现在我们把所有代码放在一起,并用我们之前设定的 CIFAR-10 参数来实例化模型,看看它的输入和输出。
import torch
from torch import nn
# --- 组件 1: PatchEmbedding ---
class PatchEmbedding(nn.Module):
def __init__(self, image_size, patch_size, in_channels, dim):
super().__init__()
if not (image_size % patch_size == 0):
raise ValueError("Image dimensions must be divisible by the patch size.")
self.num_patches = (image_size // patch_size) ** 2
self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.projection(x)
x = x.flatten(2)
x = x.transpose(1, 2)
return x
# --- 组件 2: TransformerEncoderBlock ---
class TransformerEncoderBlock(nn.Module):
def __init__(self, dim, heads, mlp_dim, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
attn_output, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))
x = x + attn_output
mlp_output = self.mlp(self.norm2(x))
x = x + mlp_output
return x
# --- 主模型: VisionTransformer ---
class VisionTransformer(nn.Module):
def __init__(self, image_size, patch_size, in_channels, num_classes,
dim, depth, heads, mlp_dim, dropout=0.1):
super().__init__()
self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
num_patches = self.patch_embedding.num_patches
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(dropout)
self.transformer_encoder = nn.Sequential(
*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.patch_embedding(img)
b, n, d = x.shape
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer_encoder(x)
cls_token_output = x[:, 0]
output = self.mlp_head(cls_token_output)
return output
# --- 实例化并测试 ---
# CIFAR-10 实例参数
BATCH_SIZE = 4
IMAGE_SIZE = 32
IN_CHANNELS = 3
PATCH_SIZE = 4
NUM_CLASSES = 10
DIM = 512
DEPTH = 6
HEADS = 8
MLP_DIM = 2048
# 创建模型实例
vit_model = VisionTransformer(
image_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
in_channels=IN_CHANNELS,
num_classes=NUM_CLASSES,
dim=DIM,
depth=DEPTH,
heads=HEADS,
mlp_dim=MLP_DIM
)
# 创建一个假的输入图像张量 (Batch, Channels, Height, Width)
dummy_img = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
# 将图像输入模型
logits = vit_model(dummy_img)
# 打印输出的形状
print(f"输入图像形状: {dummy_img.shape}")
print(f"模型输出 (Logits) 形状: {logits.shape}")
# 检查输出形状是否正确
assert logits.shape == (BATCH_SIZE, NUM_CLASSES)
print("\n模型构建成功,输入输出形状正确!")