微调技术:Prefix-tuning vs Prompt-tuning vs P-tuning

发布于:2025-06-17 ⋅ 阅读:(22) ⋅ 点赞:(0)

这三种技术都是参数高效微调Parameter-Efficient Fine-Tuning, PEFT方法,主要用于在不微调大模型全部参数的前提下,通过较少的可训练参数实现良好的下游任务性能。它们的核心思想是:冻结预训练模型参数,仅调整少量额外模块或输入,以适配新任务。


一、理论说明

1. Prompt-tuning(离散或软提示)

基本思想:
将任务信息嵌入为一个可学习的“提示(prompt)”,并将其作为输入的一部分传入模型。

特点:

  • 引入一组可学习的嵌入向量(soft prompts),拼接在原始输入前。

  • 模型主体完全冻结,只有这些嵌入向量是可训练的。

  • 形式上可以表示为:

    输入: [P1, P2, ..., Pn, x1, x2, ..., xm]
          ⬑ prompt ⬑ 原始输入
    

优点:

  • 参数量极小(只需优化几百个向量)。
  • 可适配多个任务。

2. Prefix-tuning

基本思想:
在模型的每一层 Transformer 的 key 和 value 前加上一段可学习的 prefix 向量,用于控制注意力计算。

实现方式:

  • 在每一层添加长度为 l 的 prefix,作为 attention 的一部分。
  • 不改动 token embedding,只作用在注意力机制中。

优点:

  • 比 prompt-tuning 表现更强,尤其在复杂任务中。
  • 对注意力路径提供更细粒度控制。
  • 通常比 full fine-tuning 参数少两个数量级。

3. P-tuning(P-tuning v1 / v2)

P-tuning v1:

  • 将软提示视作可学习的嵌入向量,插入在输入 embedding 前,类似 prompt-tuning。
  • 和 prompt-tuning 差别小,更多是实验设定不同。

P-tuning v2:

  • 结合了 prefix-tuning 的思想,将 prompt 应用于所有 transformer 层,并配合 LSTM 或 MLP 学习提示生成。
  • 更适用于深层模型(如 GPT-2、T5)。

优点:

  • 适用于大规模模型和更复杂任务。
  • 兼顾了 prefix-tuning 的性能与 prompt-tuning 的轻量性。

总结对比:

技术 插入位置 可学习参数位置 参数量 表现 适用场景
Prompt-tuning 输入 token 前部 输入 embedding 前的 soft prompt 最少 较弱 简单任务
Prefix-tuning 每层注意力前缀 每层 Transformer 的 key/value 中等 较强 多层深模型
P-tuning v2 输入 + 各层前缀 LSTM/MLP生成的 prompt embedding 中等 更强 高性能需求

如用于 NLP 下游任务(如情感分析、问答等),Prefix-tuning 和 P-tuning v2 在保持模型冻结的前提下能达到甚至接近 full fine-tuning 的效果。

二、代码实例

以下是每种技术的核心代码片段(以 PyTorch + HuggingFace Transformers 框架为例),聚焦其主要实现逻辑。假设基础模型为 GPT2Model


1. Prompt-tuning(Soft Prompt)

class PromptTuningModel(nn.Module):
    def __init__(self, base_model, prompt_length=10):
        super().__init__()
        self.base_model = base_model
        self.prompt_length = prompt_length
        self.prompt_embeddings = nn.Parameter(
            torch.randn(prompt_length, base_model.config.hidden_size)
        )

    def forward(self, input_ids, attention_mask=None):
        inputs_embeds = self.base_model.transformer.wte(input_ids)  # word embeddings
        batch_size = input_ids.size(0)

        # Expand prompt to batch
        prompt = self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prompt, inputs_embeds], dim=1)

        if attention_mask is not None:
            prompt_mask = torch.ones(batch_size, self.prompt_length).to(attention_mask.device)
            attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs

2. Prefix-tuning

class PrefixTuningModel(nn.Module):
    def __init__(self, base_model, prefix_length=5):
        super().__init__()
        self.base_model = base_model
        self.prefix_length = prefix_length
        self.num_layers = base_model.config.n_layer
        self.n_heads = base_model.config.n_head
        self.head_dim = base_model.config.hidden_size // self.n_heads

        # 每层有独立的 prefix key/value
        self.prefix_key = nn.Parameter(torch.randn(
            self.num_layers, prefix_length, self.n_heads, self.head_dim))
        self.prefix_value = nn.Parameter(torch.randn(
            self.num_layers, prefix_length, self.n_heads, self.head_dim))

    def forward(self, input_ids, attention_mask=None):
        # 将 prefix 插入每一层的 attention 中,需修改 transformer 层的 forward
        # 这里只示意插入机制,具体需修改 GPT2Attention 中的 attn 计算
        raise NotImplementedError("需在模型内部改造 attention 层逻辑")

实际使用推荐:使用 PEFT库(HuggingFace)PrefixTuningConfig


3. P-tuning v2

class PTuningV2Model(nn.Module):
    def __init__(self, base_model, prompt_length=10):
        super().__init__()
        self.base_model = base_model
        self.prompt_length = prompt_length
        self.num_layers = base_model.config.n_layer
        self.hidden_size = base_model.config.hidden_size

        # 可用 LSTM/MLP 编码 prompt
        self.prompt_encoder = nn.LSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=False
        )

        self.raw_prompt = nn.Parameter(torch.randn(prompt_length, self.hidden_size))

    def forward(self, input_ids, attention_mask=None):
        batch_size = input_ids.size(0)
        prompt_input = self.raw_prompt.unsqueeze(0).expand(batch_size, -1, -1)
        encoded_prompt, _ = self.prompt_encoder(prompt_input)

        inputs_embeds = self.base_model.transformer.wte(input_ids)
        inputs_embeds = torch.cat([encoded_prompt, inputs_embeds], dim=1)

        if attention_mask is not None:
            prompt_mask = torch.ones(batch_size, self.prompt_length).to(attention_mask.device)
            attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs

4. 调用PEFT训练

注意版本:transformers==4.44.2

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig, TaskType
# from dataset import load_toy_dataset
from datasets import Dataset
def load_toy_dataset():
    texts = [
        "Translate English to French: Hello",
        "Translate English to French: How are you?",
        "Translate English to French: What is your name?",
    ]
    labels = [
        "Bonjour",
        "Comment ça va?",
        "Comment tu t'appelles?",
    ]
    data = {"input": texts, "label": labels}
    return Dataset.from_dict(data)
# 可选:Prompt-Tuning / Prefix-Tuning / P-Tuning v2
PEFT_METHOD = "p-tuning"  # or "prompt-tuning" / "prefix-tuning"

def get_peft_config(method, model_name):
    if method == "prompt-tuning":
        return PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_length=10, inference_mode=False, base_model_name_or_path=model_name)
    elif method == "prefix-tuning":
        return PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10, inference_mode=False, base_model_name_or_path=model_name)
    elif method == "p-tuning":
        return PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10, encoder_hidden_size=128, base_model_name_or_path=model_name)
    else:
        raise ValueError("Unknown PEFT method")

def preprocess(example, tokenizer):
    input_ids = tokenizer(example["input"], truncation=True, padding="max_length", max_length=32, return_tensors="pt").input_ids
    labels = tokenizer(example["label"], truncation=True, padding="max_length", max_length=32, return_tensors="pt").input_ids
    example["input_ids"] = input_ids[0]
    example["labels"] = labels[0]
    return example

def main():
    model_name = "/Users/yanjp/.cache/modelscope/hub/models/Qwen/Qwen3-0_6B"

    # 加载 tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 加载模型
    base_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    peft_config = get_peft_config(PEFT_METHOD, model_name)
    model = get_peft_model(base_model, peft_config)

    dataset = load_toy_dataset()
    dataset = dataset.map(lambda e: preprocess(e, tokenizer))
    dataset.set_format(type="torch", columns=["input_ids", "labels"])

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./outputs",
            per_device_train_batch_size=2,
            num_train_epochs=5,
            logging_steps=1,
            save_strategy="no"
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
    )

    model.print_trainable_parameters()
    trainer.train()

if __name__ == "__main__":
    main()

网站公告

今日签到

点亮在社区的每一天
去签到