Visual Prompt Tuning (VPT)

Published: 2025-06-14

Method notes

I. Network Structure

Trainable parameters

prompt module 

PromptedVisionTransformer(
  (transformer): PromptedTransformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(...)  # 12 Blocks + encoder_norm, see the full structure below
    (prompt_dropout): Dropout(p=0.0, inplace=False)
    (prompt_proj): Identity()
  )
  (head): Identity()
)

head 

MLP(
  (projection): Sequential()
  (last_layer): Linear(in_features=768, out_features=200, bias=True)
)
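In other words, only the prompt-related parameters and the MLP head receive gradients; the ViT backbone stays frozen. A minimal sketch of that freezing logic (not the exact repo code, which drives this through its configuration):

import torch.nn as nn

def freeze_all_but_prompts_and_head(model: nn.Module) -> None:
    """Minimal sketch: keep only prompt_* and head parameters trainable."""
    for name, p in model.named_parameters():
        # prompt_embeddings / prompt_proj live under "prompt"; the classifier under "head"
        p.requires_grad = ("prompt" in name) or name.startswith("head.")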

Overall structure

ViT(
  (enc): PromptedVisionTransformer(
    (transformer): PromptedTransformer(
      (embeddings): Embeddings(
        (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): Encoder(
        (layer): ModuleList(
          (0-11): 12 x Block(
            (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (ffn): Mlp(
              (fc1): Linear(in_features=768, out_features=3072, bias=True)
              (fc2): Linear(in_features=3072, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (attn): Attention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (out): Linear(in_features=768, out_features=768, bias=True)
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (proj_dropout): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
          )
        )
        (encoder_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      )
      (prompt_dropout): Dropout(p=0.0, inplace=False)
      (prompt_proj): Identity()
    )
    (head): Identity()
  )
  (head): MLP(
    (projection): Sequential()
    (last_layer): Linear(in_features=768, out_features=200, bias=True)
  )
)

The class inherits from the following parent class:

class PromptedTransformer(Transformer):

and the parent class contains the backbone setup:

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

The ViT encoder layer (Block) used inside Encoder is defined as:

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # pre-norm self-attention with residual connection
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        # pre-norm feed-forward (MLP) with residual connection
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

Here, embeddings is the ViT patch projection layer, and encoder is the stack of 12 encoder layers (Block).
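As a quick shape check, the following standalone sketch (assuming 224x224 inputs and the ViT-B/16 configuration printed above) traces the patch projection up to the 197-token sequence:

import torch
import torch.nn as nn

# Patch projection: 224x224 image -> 14x14 grid of 16x16 patches -> 196 tokens
patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(8, 3, 224, 224)
x = patch_embed(x)                     # (8, 768, 14, 14)
x = x.flatten(2).transpose(1, 2)       # (8, 196, 768): 196 patch tokens
cls_token = torch.zeros(8, 1, 768)
x = torch.cat((cls_token, x), dim=1)   # (8, 197, 768): CLS + patch tokens
print(x.shape)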

II. Dataset Configuration

1. Converting the annotations to train.json

Only the CUB_200_2011 dataset is used; its annotation files need to be converted into JSON.

The conversion script is as follows:

import json
import os


def process_files(label_file, image_file, class_file, output_train, output_val):
    # File 1: train/test split (0 = test set, 1 = training set)
    with open(label_file, 'r') as f:
        labels = [int(line.split()[1]) for line in f.readlines()]

    # File 2: image paths
    with open(image_file, 'r') as f:
        image_paths = [line.split()[1] for line in f.readlines()]

    # File 3: class labels (1-based ids, 1-200)
    with open(class_file, 'r') as f:
        classes = [int(line.split()[1]) for line in f.readlines()]

    # The three files must have the same number of entries
    assert len(labels) == len(image_paths) == len(classes), "entry counts of the three files do not match"

    # Build the training-set and test-set dictionaries
    train_data = {}
    val_data = {}

    for i in range(len(labels)):
        img_path = image_paths[i]
        class_id = classes[i]

        if labels[i] == 1:  # training set
            train_data[img_path] = class_id
        else:  # test set
            val_data[img_path] = class_id

    # Write the JSON files
    with open(output_train, 'w') as f:
        json.dump(train_data, f, indent=4)

    with open(output_val, 'w') as f:
        json.dump(val_data, f, indent=4)

    print(f"Training set saved to {output_train} with {len(train_data)} samples")
    print(f"Validation set saved to {output_val} with {len(val_data)} samples")


# File paths
label_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train_test_split.txt'   # replace with the actual file path
image_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\images.txt'   # replace with the actual file path
class_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\image_class_labels.txt' # replace with the actual file path
output_train = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train.json'
output_val = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\val.json'

# Process the files
process_files(label_file, image_file, class_file, output_train, output_val)
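A quick way to sanity-check the generated files, reusing the paths defined above (CUB_200_2011 ships 5,994 training and 5,794 test images, so those are the expected counts):

import json

with open(output_train) as f:
    train_data = json.load(f)
with open(output_val) as f:
    val_data = json.load(f)

print(len(train_data), len(val_data))   # expected: 5994 5794
print(next(iter(train_data.items())))   # one (relative image path, class id) entry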

2. Inputs

inputs --- (images)

targets --- (class labels)
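Concretely, an illustrative batch (assuming 224x224 crops and a batch size of 8, matching the shape comments later in this section) would look like:

import torch

inputs = torch.randn(8, 3, 224, 224)    # images
targets = torch.randint(0, 200, (8,))   # class labels for the 200 bird classes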

III. Training Pipeline

1. Prompt handling

embedding_output = self.incorporate_prompt(x)

where incorporate_prompt is defined as:

def incorporate_prompt(self, x):
    # combine prompt embeddings with image-patch embeddings
    B = x.shape[0]
    # after CLS token, all before image patches
    x = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)  (8, 197, 768)
    x = torch.cat((
            x[:, :1, :],
            self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
            x[:, 1:, :]
        ), dim=1)  # (8, 202, 768)
    # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)

    return x
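The effect of this concatenation can be reproduced with dummy tensors (illustrative only, using num_tokens = 5 and hidden_dim = 768 as in the text):

import torch

B, num_tokens, hidden_dim = 8, 5, 768
x = torch.randn(B, 197, hidden_dim)                        # CLS + 196 patch tokens
prompt = torch.zeros(1, num_tokens, hidden_dim).expand(B, -1, -1)
x = torch.cat((x[:, :1, :], prompt, x[:, 1:, :]), dim=1)   # insert prompts right after CLS
print(x.shape)                                             # torch.Size([8, 202, 768])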

The class token is defined inside Embeddings:

self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

To incorporate the prompts, the input first goes through patch embedding, and the sequence is then concatenated as [cls, prompt, patch_tokens].

Here the prompt is a sequence of 5 tokens and is a trainable parameter:

self.prompt_embeddings = nn.Parameter(torch.zeros(
    1, num_tokens, prompt_dim))  # (1, 5, 768), registered as a trainable parameter
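The uniform range `val` (the same `val` that reappears in the deep variant below) is computed xavier-style from the patch size and the prompt dimension. A hedged sketch of that initialization, assuming patch_size = (16, 16) and prompt_dim = 768:

import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn

patch_size, prompt_dim, num_tokens = (16, 16), 768, 5
val = math.sqrt(6.0 / float(3 * reduce(mul, patch_size, 1) + prompt_dim))
prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))
nn.init.uniform_(prompt_embeddings.data, -val, val)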

2. After concatenation, the sequence is passed through the stacked ViT encoder layers for attention computation:

def forward(self, hidden_states):
    attn_weights = []
    for layer_block in self.layer:
        hidden_states, weights = layer_block(hidden_states)
        if self.vis:
            attn_weights.append(weights)
    encoded = self.encoder_norm(hidden_states)
    return encoded, attn_weights

From the encoder output, the CLS token is taken as the feature:

x = x[:, 0]
logits = self.head(x)  # note: this head is Identity(), so it simply returns the CLS feature

Seen from the outer ViT wrapper, this whole encoder pass is

x = self.enc(x)  # batch_size x self.feat_dim  (8, 768)

3. The CLS-token feature is fed into the head network to predict the class logits.

The output at this point is

x = self.head(x)  # (8,200)
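Putting the backbone pass and the head together, the outer wrapper amounts to the following (a minimal sketch under the names used above, not the exact repo code):

import torch.nn as nn

class ViTSketch(nn.Module):
    """Minimal sketch of the outer wrapper: prompted backbone + trainable MLP head."""
    def __init__(self, enc: nn.Module, head: nn.Module):
        super().__init__()
        self.enc = enc      # PromptedVisionTransformer, returns the CLS feature
        self.head = head    # MLP head: 768 -> 200 logits

    def forward(self, x):
        x = self.enc(x)     # (B, 768)
        return self.head(x) # (B, 200)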

4. Computing the loss against the labels

loss = self.cls_criterion(outputs, targets, self.cls_weights)

where the underlying loss is

loss = F.cross_entropy(logits, targets, weight, reduction="none")
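Since reduction="none" returns one loss value per sample, the criterion then averages over the batch. Roughly (a hedged sketch with dummy tensors):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 200)
targets = torch.randint(0, 200, (8,))
weight = torch.ones(200)                                            # per-class weights (cls_weights)
loss = F.cross_entropy(logits, targets, weight, reduction="none")   # shape (8,)
loss = loss.sum() / targets.shape[0]                                # batch average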

IV. The Deep Variant (VPT-Deep)

The main difference is the prompt setup.

Initialization in the deep variant:

if self.prompt_config.DEEP:  # noqa
    total_d_layer = config.transformer["num_layers"] - 1  # 12 - 1 = 11
    self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
        total_d_layer, num_tokens, prompt_dim))
    # xavier_uniform initialization
    nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

That is, in addition to the shallow prompt fed to the first layer, prompts for the remaining 11 layers are initialized as well (deep_prompt_embeddings has shape (11, 5, 768)).

    def forward_deep_prompt(self, embedding_output):
        attn_weights = []
        hidden_states = None
        weights = None
        B = embedding_output.shape[0]
        num_layers = self.vit_config.transformer["num_layers"]  # 12

        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output)
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))

                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        deep_prompt_emb,
                        hidden_states[:, (1+self.num_tokens):, :]
                    ), dim=1)

                hidden_states, weights = self.encoder.layer[i](hidden_states)

            if self.encoder.vis:
                attn_weights.append(weights)

        encoded = self.encoder.encoder_norm(hidden_states)
        return encoded, attn_weights

That is, the first layer is handled exactly as in the shallow variant, while every subsequent layer replaces the prompt slots with the prompts learned for that layer.
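With dummy tensors, the replacement step looks like this; note that the sequence length stays at 202 because the old prompt slots are sliced out before the new ones are inserted (illustrative only, num_tokens = 5, hidden_dim = 768):

import torch

B, num_tokens, hidden_dim = 8, 5, 768
hidden_states = torch.randn(B, 1 + num_tokens + 196, hidden_dim)    # (8, 202, 768)
deep_prompt_emb = torch.zeros(1, num_tokens, hidden_dim).expand(B, -1, -1)
hidden_states = torch.cat((
    hidden_states[:, :1, :],                  # CLS token
    deep_prompt_emb,                          # this layer's own prompts
    hidden_states[:, (1 + num_tokens):, :],   # patch tokens (old prompts dropped)
), dim=1)
print(hidden_states.shape)                    # torch.Size([8, 202, 768])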