从零开始搭建CLIP模型实现基于文本的图像检索

发布于:2025-04-20 ⋅ 阅读:(42) ⋅ 点赞:(0)

CLIP原理简介

论文链接源码链接

CLIP模型由OpenAI在2021年提出,利用双Decoder(Dual Encoder)的架构来学习图像和文本之间的对应关系,是多模态大模型的开创之作,为后续许多高效的多模态模型的提出打下基础。CLIP是一个预训练模型(Pre-trained Model),在学习到图像–文本特征之间的关联后可以迁移到各种下游任务中,如图像分类,文本引导图像分割和目标检测,图像文本检索等。由于模型学习到的是文本语义和图像语义之间的关联,使得其zero-shot能力非常强大,根据论文中的描述,CLIP在很多数据集上zero-shot的结果甚至超越了许多训练好的模型的效果。CLIP的训练范式如下:

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/1d112d364a60434bba8dd07d42d2a1c6.png

CLIP的结构非常简单,数据集包含大量的图像文本对,图像经过图像编码器得到图像特征,文本经过文本编码器得到文本特征,将图像特征和文本特征按照数据集中的对应关系进行配对,不配对的特征给予惩罚,从上图中可以看出,我们希望矩阵中蓝色的值趋近于1,其余值趋近于0,采用对比学习的方式对模型进行训练,算法的伪代码如下:

在这里插入图片描述
从损失函数中可以看出,分别对特征对比矩阵的行和列进行交叉熵损失函数计算,并取平均得到最终的loss。图像编码器一般有两种选择:ResNet和ViT;文本编码器采用Transformer Encoder,均是各自领域中优秀的特征提取网络。
CLIP的推理范式如下:

在这里插入图片描述
在推理阶段,图像编码器中输入图像获取图像特征,文本编码器中输入文本获取文本特征,将图像特征向量和文本特征向量的转置相乘得到每张图像对每个文本的特征相似度,相似度最高的文本即描述了该图像中物体所属的类别。

代码实现

Flickr8k数据集下载,提取码:fbfz
DistilBert模型文件下载

我的运行环境:
CUDA 11.8
pytorch 2.2.2
transformers 4.44.0 # 用于从HuggingFace上加载预训练模型


数据集预览:
图片示例

图片示例

在这里插入图片描述

文本示例

由于作者的显卡算力有限,选取Flickr8k数据集进行模型训练,其中包含8k个图像文本对,其中一张图像对应5条文本。图像编码器采用ResNet50,直接从timm库中导入;文本编码器采用DistilBert,即轻量化的Bert模型,从HuggingFace上下载。闲话少说,小二,上菜!

### 模型参数配置 ###
import argparse
from dataclasses import dataclass

parser = argparse.ArgumentParser(description="CLIP from zero")
parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder')  # 存放图像的文件路径
parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder')  # 存放文本的文件路径
parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight')  # 存放训练权重的文件路径
args = parser.parse_args()

@dataclass
class CLIPConfig:
    image_path: str = args.image_dir  # 图像存放路径
    image_size: int = 224  # resize后的图像尺寸,便于构建Dataloader
    caption_path: str = args.caption_dir  # 文本存放路径
    batch_size: int = 8  # 一个批次中的数据数量
    epochs: int = 3  # 训练世代

    image_encoder_model: str = "resnet50"  # 图像编码器的名称
    image_embedding_dim: int = 2048  # 图像特征的维度
    text_encoder_model: str = "distilbert-base-uncased"  # 文本编码器的名称
    text_embedding_dim: int = 768  # 文本特征的维度
    text_tokenizer: str = text_encoder_model  # 文本分词器模型的名称
    max_length: int = 200  # 文本编码器可输入的最长文本长度

    pretrained: bool = False  # 是否加载预训练好的编码器
    trainable: bool = True  # 在训练过程中是否更新编码器的参数
    temperature: float = 1.0  # 计算loss时的正则化系数

    proj_dim: int = 256  # 图像特征和文本特征统一后的维度
    dropout_rate: float = 0.1  # dropout系数,避免过拟合


### 载入数据集并初始化 ###
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import pandas as pd
import cv2

class CLIPDataset(Dataset):
    def __init__(self, config, image_path, caption_path, transforms=True):
        """
        图片文件名和标题的长度必须相同
        如果一个图片对应多个标题,该图片文件名需要重复多次
        """
        self.image_path = image_path  # 图像路径
        self.caption_path = caption_path  # 文本路径
        self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv")  # 读取文本
        self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)  # 载入分词器

        self.image_filenames = self.dataframe["image"].values  # 获取图像文件名
        self.captions = list(self.dataframe["caption"].values)   # 获取图像对应的描述文本
        self.encoded_captions = self.tokenizer(self.captions, 
                                               padding=True, 
                                               truncation=True, 
                                               max_length=config.max_length)  # 文本分词
        self.transforms = transforms  # 对输入图像进行预处理

    def __getitem__(self, idx):  # 获取数据集中第idx个数据,其中包含图片名称和对应的标题(可能不止一个)
        item = {
            key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()
        }

        image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}")  # 获取原始图像
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.get_transforms(mode="train")(image=image)["image"]  # 对图像进行预处理

        item["image"] = torch.tensor(image).permute(2, 0, 1).float()  # 将图片转换为tensor格式,并调整为RGB顺序
        item["caption"] = self.captions[idx]  # 获取标题

        return item

    def __len__(self):
        return len(self.captions)  # 获取文本长度

    def get_transforms(self, mode="train"):
        if mode == "train":
            return A.Compose(
                [
                    A.Resize(config.image_size, config.image_size, always_apply=True),  # 对图像进行resize
                    A.Normalize(max_pixel_value=255.0, always_apply=True)  # 对像素值进行归一化
                ]
            )

### 图像编码器 ###
import torch.nn as nn
import timm

class ImageEncoder(nn.Module):
    """
    图像编码器,采用ResNet50
    """
    def __init__(self, config):
        super().__init__()
        self.model = timm.create_model(config.image_encoder_model, 
                                       pretrained=config.pretrained, 
                                       num_classes=0, global_pool="avg")  # 创建ResNet50

        for p in self.model.parameters():
            p.requires_grad = config.trainable  # 设置参数可训练

    def forward(self, x):
        image_encoded = self.model(x)  # 获得图像特征编码,形状为[batch_size, image_embedding_dim]
        return image_encoded

### 文本编码器 ###
class TextEncoder(nn.Module):
    """
    文本编码器,采用DistilBERT
    """
    def __init__(self, config):
        super().__init__()
        if config.pretrained:
            self.model = DistilBertModel.from_pretrained(config.text_encoder_model)  # 导入下载好的模型文件
        else:
            self.model = DistilBertModel(DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = config.trainable  # 设置参数可训练

        self.target_token_idx = 0
    
    # 提取出和图像对应的文本特征向量
    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        text_encoded = output.last_hidden_state[:, self.target_token_idx, :]  # [batch_size, text_embedding_dim]
        return text_encoded

### 投影层 (MLP) ###
class ProjectionHead(nn.Module):
    """
    将图像编码和文本编码映射到相同维度
    """
    def __init__(self, config, input_embedding_dim):
        super().__init__()
        self.proj = nn.Linear(input_embedding_dim, config.proj_dim)
        self.act_fn = nn.GELU()
        self.fc = nn.Linear(config.proj_dim, config.proj_dim)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm = nn.LayerNorm(config.proj_dim)

    def forward(self, x):
        x_proj = self.proj(x)
        x = self.act_fn(x_proj)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + x_proj
        x = self.layer_norm(x)

        return x

### 定义损失函数 ###
def cross_entropy(logits, labels, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-labels * log_softmax(logits)).sum(dim=1)
    if reduction == 'mean':
        return loss.mean()
    else:
        return loss.sum()

### 模型主体 ###
import torch.nn.functional as F

class CLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_encoder = ImageEncoder(config)  # 实例化图像编码器
        self.text_encoder = TextEncoder(config)  # 实例化文本编码器
        self.image_proj = ProjectionHead(config, config.image_embedding_dim)  # 图像特征投影
        self.text_proj = ProjectionHead(config, config.text_embedding_dim)  # 文本特征投影
        self.temperature = config.temperature

    def forward(self, batch):
        image_features = self.image_encoder(batch["image"])  # 图像编码
        
        # 文本编码,tokenizer处理后的文本序列自带input_ids和attention_mask
        text_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])

        image_embeddings = self.image_proj(image_features)  # 图像特征投影
        text_embeddings = self.text_proj(text_features)  # 文本特征投影

        logits = (text_embeddings @ image_embeddings.T) / self.temperature  # tensor形状为[batch_size, batch_size]
        images_similarity = image_embeddings @ image_embeddings.T  # tensor形状为[batch_size, batch_size]
        text_similarity = text_embeddings @ text_embeddings.T  # tensor形状为[batch_size, batch_size]

        # 软标签,不配对的位置设置为较小的数,而非0
        labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1)  
        
        loss_T = cross_entropy(logits, labels)  # 计算文本损失
        loss_I = cross_entropy(logits.T, labels.T)  # 计算图像损失
        total_loss = (loss_T + loss_I) / 2  # 对比学习平均损失

        return total_loss, logits

### 训练函数 ###
def train(model, optimizer, scheduler, train_loader, device):
    model.train()  # 模型设置为训练模式
    train_loss = 0
    train_loader = tqdm(train_loader, total=len(train_loader))  # 显示训练进度条
    cnt = 0
    for batch in train_loader:
        # print(batch.keys())
        cnt += 1
        batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}  # 将dataloader中一个batch的数据转换为字典形式
        loss, _ = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(metrics=loss.item())  # 根据上次训练的损失更新学习率

        train_loss += loss.item()

        # 训练100个batch显示一次loss
        if cnt % 100 == 0:
            print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')

    return train_loss / len(train_loader)  # 平均训练损失

### 测试函数 ###
def eval(model, val_loader, device):
    model.eval()  # 模型设置为测试模式
    val_loss = 0
    val_loader = tqdm(val_loader, total=len(val_loader))
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
            loss, _ = model(batch)
            val_loss += loss.item()

    return val_loss / len(val_loader)  # 平均测试损失

if __name__ == '__main__':
    config = CLIPConfig()  # 实例化配置信息
    model = CLIP(config)  # 实例化CLIP模型
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # 查看模型的总参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6} M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)

    dataset = CLIPDataset(config, args.image_dir, args.caption_dir)  # 读取并预处理数据
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])  # 80%为训练数据,20%为测试数据
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    # 开始训练
    best_loss = float("inf")
    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        train_loss_avg = train(model, optimizer, scheduler, train_loader, device)
        val_loss_avg = eval(model, val_loader, device)

        if val_loss_avg < best_loss:
            best_loss = val_loss_avg
            torch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')
            print("Best model saved!")

    # 图像文本检索推理并可视化
    # dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")
    # tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    # model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))
    # model.eval()
    # 
    # image_embeddings = []
    # with torch.no_grad():
    #     for batch in tqdm(dataloader):
    #         image_features = model.image_encoder(batch["image"].to(device))  # 获取图像特征
    #         cur_image_embeddings = model.image_proj(image_features)  # [batch_size, proj_dim]  # 图像特征投影
    #         image_embeddings.append(cur_image_embeddings)  # 将一个batch的图像特征保存
    # 
    # image_embeddings = torch.cat(image_embeddings, dim=0)  # [image_number, proj_dim]
    # input_query = "two dogs sitting on the grass"  # 输入文本
    # image_filenames = dataframe["image"].values  # 待检索的图片
    # 
    # encoded_query = tokenizer([input_query])  # 对输入文本进行分词
    # batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}
    # 
    # with torch.no_grad():
    #     text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"])  # 获取文本特征
    #     text_embeddings = model.text_proj(text_features)  # 文本特征投影,与图像特征维度一致
    # 
    # image_embeddings_n = F.normalize(image_embeddings, dim=-1)  # [image_number, proj_dim]
    # text_embeddings_n = F.normalize(text_embeddings, dim=-1)  # [1, proj_dim]
    # dot_similarity = text_embeddings_n @ image_embeddings_n.T  # 输入文本的特征和数据集中每张图像特征之间的相似度
    # 
    # values, indices = torch.topk(dot_similarity.squeeze(0), k=45)  # 获取前45个相似度最高的图像
    # matches = [image_filenames[idx] for idx in indices[::5]]  # 获取对应的图像文件名(9张图像)
    # 
    # f, axes = plt.subplots(3, 3, figsize=(10, 10))
    # f.suptitle(f"Retrieving text: {input_query}")  # 设置主标题
    # for match, ax in zip(matches, axes.flatten()):  # 显示检索出的图像
    #     image = cv2.imread(f"{args.image_dir}/{match}")
    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    #     ax.imshow(image)
    #     ax.axis("off")
    # 
    # plt.show()

理想结果:

在这里插入图片描述

参考链接

https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/


网站公告

今日签到

点亮在社区的每一天
去签到