基于 LoRA的广义知识蒸馏(GKD)训练
flyfish
通过参数高效的 LoRA(低秩适应)技术,结合广义知识蒸馏(GKD)方法,让小尺寸的学生模型(如 Qwen2-0.5B-Instruct)高效学习大尺寸教师模型(如 Qwen2-1.5B-Instruct)的知识和能力,最终在减少计算资源消耗的前提下,提升小模型的对话性能,使其接近大模型的水平。
python examples/scripts/gkd.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \ # 学生模型:小模型,待蒸馏的模型
--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ # 教师模型:大模型,提供知识的模型
--dataset_name trl-lib/chatbot_arena_completions \ # 训练数据集:对话竞技场数据(含高质量对话)
--learning_rate 2e-4 \ # 学习率:LoRA微调通常用更大的学习率(全量训练一般2e-5)
--per_device_train_batch_size 4 \ # 单设备训练批次大小
--gradient_accumulation_steps 8 \ # 梯度累积步数:总批次=4*8=32(节省显存)
--output_dir gkd-model \ # 模型保存路径
--num_train_epochs 1 \ # 训练轮次
--push_to_hub \ # 训练后推送到Hugging Face Hub
--gradient_checkpointing \ # 启用梯度检查点:牺牲少量速度换显存
--use_peft \ # 启用PEFT(参数高效微调)框架,这里用于LoRA
--lora_r 64 \ # LoRA的秩(秩越低,参数量越少)
--lora_alpha 16 # LoRA的缩放系数(控制更新幅度)
解析参数 → 加载模型 / Tokenizer → 加载数据集 → 初始化 GKD 训练器 → 执行训练 → 保存模型。
# 导入必要的库
# 加载数据集的工具
from datasets import load_dataset
# 加载分词器和生成配置的工具
from transformers import AutoTokenizer, GenerationConfig
# 从trl库导入GKD相关的配置、训练器和工具
from trl import (
GKDConfig, # GKD训练的核心配置类
GKDTrainer, # GKD训练器,用于实现广义知识蒸馏
LogCompletionsCallback, # 记录生成结果的回调函数,用于评估
ModelConfig, # 模型相关配置(如LoRA参数、量化设置等)
ScriptArguments, # 脚本级参数(如数据集路径、分裂等)
TrlParser, # trl库专用的参数解析器
get_kbit_device_map, # 获取量化模型的设备映射(自动分配GPU/CPU)
get_peft_config, # 获取PEFT配置(如LoRA参数)
get_quantization_config, # 获取量化配置(如4/8位量化)
)
if __name__ == "__main__":
# 初始化参数解析器,支持解析三类配置:脚本参数、GKD训练配置、模型配置
parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
# 解析命令行参数,得到三个配置对象
# script_args:数据集路径、分裂等脚本级参数
# training_args:GKD训练的核心参数(学习率、批次大小等)
# model_args:模型相关参数(LoRA配置、量化设置等)
script_args, training_args, model_args = parser.parse_args_and_config()
################
# 模型与分词器配置
################
# 根据model_args获取量化配置(如4位/8位量化),用于减少显存占用
quantization_config = get_quantization_config(model_args)
# 定义学生模型的初始化参数
model_kwargs = dict(
revision=model_args.model_revision, # 模型版本(如特定commit哈希)
trust_remote_code=model_args.trust_remote_code, # 是否信任模型的自定义代码(如非标准架构)
attn_implementation=model_args.attn_implementation, # 注意力实现方式(如flash attention加速)
torch_dtype=model_args.torch_dtype, # 数据类型(如float16/bfloat16,节省显存)
# 启用梯度检查点时禁用缓存(两者冲突),否则启用缓存加速
use_cache=False if training_args.gradient_checkpointing else True,
# 量化时自动分配设备(GPU/CPU),非量化时不指定
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config, # 量化配置(如4位量化参数)
)
# 将学生模型参数传递给训练配置
training_args.model_init_kwargs = model_kwargs
# 定义教师模型的初始化参数(与学生模型类似,但有细微差别)
teacher_model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
use_cache=True, # 教师模型仅用于推理,启用缓存加速生成
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
# 将教师模型参数传递给训练配置
training_args.teacher_model_init_kwargs = teacher_model_kwargs
# 加载分词器(与学生模型匹配,确保格式一致)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, # 分词器路径(与学生模型相同)
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
padding_side="left", # 左填充(生成任务常用,避免右填充影响生成逻辑)
)
# 若分词器未定义pad_token,用eos_token代替(确保填充功能正常)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# 数据集加载
################
# 加载指定数据集(如trl-lib/chatbot_arena_completions对话数据集)
# script_args.dataset_name:数据集名称,script_args.dataset_config:数据集配置(如子数据集)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# 训练初始化
################
# 初始化GKD训练器,核心组件
trainer = GKDTrainer(
model=model_args.model_name_or_path, # 学生模型路径(如Qwen/Qwen2-0.5B-Instruct)
teacher_model=training_args.teacher_model_name_or_path, # 教师模型路径(如Qwen/Qwen2-1.5B-Instruct)
args=training_args, # 训练配置(学习率、批次大小等)
train_dataset=dataset[script_args.dataset_train_split], # 训练集(如dataset["train"])
# 验证集(若启用评估)
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer, # 用于数据预处理的分词器
peft_config=get_peft_config(model_args), # PEFT配置(如LoRA参数:r=64, alpha=16)
)
# 若启用评估策略(如每轮评估),配置生成参数并添加回调
if training_args.eval_strategy != "no":
# 定义生成配置(控制模型生成行为)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, # 最大生成长度
do_sample=True, # 启用采样(而非贪心生成)
temperature=training_args.temperature # 温度参数(控制生成多样性,值越大越随机)
)
# 初始化回调函数:记录评估时的生成结果(如保存8个示例)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
# 向训练器添加回调
trainer.add_callback(completions_callback)
# 启动训练
trainer.train()
# 保存模型到输出目录
trainer.save_model(training_args.output_dir)
# 若启用push_to_hub,将模型推送到Hugging Face Hub
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
GKDTrainer
import os
import random
import textwrap
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from transformers import (
AutoModelForCausalLM, # 用于加载因果语言模型(如GPT类模型)
BaseImageProcessor, # 图像处理基类(此处未直接使用,为兼容多模态预留)
DataCollator, # 数据整理器,用于批量处理数据
FeatureExtractionMixin, # 特征提取混入类(兼容多模态)
GenerationConfig, # 生成配置,控制模型生成行为(如长度、温度等)
PreTrainedModel, # 预训练模型基类
PreTrainedTokenizerBase, # 预训练分词器基类
ProcessorMixin, # 处理器混入类(兼容多模态处理器)
is_wandb_available, # 检查是否安装wandb(实验跟踪工具)
)
from transformers.trainer_callback import TrainerCallback # 训练回调基类
from transformers.trainer_utils import EvalPrediction # 评估预测结果格式
from transformers.utils import is_peft_available # 检查是否安装PEFT(参数高效微调工具)
from ..models import prepare_deepspeed # 准备Deepspeed配置(分布式训练)
from ..models.utils import unwrap_model_for_generation # 为生成任务解包模型(如处理PEFT包装)
from .gkd_config import GKDConfig # GKD训练的核心配置类
from .sft_trainer import SFTTrainer # 监督微调训练器(GKDTrainer的父类)
from .utils import (
DataCollatorForChatML, # 针对ChatML格式的数据集整理器
disable_dropout_in_model, # 禁用模型中的dropout层(稳定训练)
empty_cache, # 清空GPU缓存(节省显存)
generate_model_card, # 生成模型卡片(README.md)
get_comet_experiment_url, # 获取Comet实验跟踪URL(若使用)
)
# 条件导入:仅当PEFT库可用时导入PeftConfig
if is_peft_available():
from peft import PeftConfig
# 条件导入:仅当wandb可用时导入wandb(实验跟踪)
if is_wandb_available():
import wandb
class GKDTrainer(SFTTrainer):
"""
广义知识蒸馏(Generalized Knowledge Distillation)训练器,继承自监督微调训练器(SFTTrainer)。
核心功能:通过教师模型指导学生模型训练,结合动态生成样本(on-policy学习)和广义JSD损失,提升小模型性能。
"""
_tag_names = ["trl", "gkd"] # 模型卡片标签
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, # 学生模型(可传入路径或实例)
teacher_model: Union[PreTrainedModel, nn.Module, str] = None, # 教师模型(可传入路径或实例)
args: Optional[GKDConfig] = None, # GKD训练配置(含蒸馏参数、训练超参等)
data_collator: Optional[DataCollator] = None, # 数据整理器(默认为ChatML格式)
train_dataset: Optional[Dataset] = None, # 训练数据集
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, # 评估数据集
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None, # 数据处理器(通常为分词器)
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, # 评估指标计算函数
callbacks: Optional[list[TrainerCallback]] = None, # 训练回调(如日志、早停等)
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), # 优化器和学习率调度器
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, # 处理logits用于计算指标
peft_config: Optional["PeftConfig"] = None, # PEFT配置(如LoRA参数)
formatting_func: Optional[Callable] = None, # 数据格式化函数(将样本转为模型输入格式)
):
# 禁用自动移除未使用的列(因GKD需要"prompts"等额外字段)
args.remove_unused_columns = False
# 初始化数据整理器:使用ChatML格式(适合对话模型),限制最大长度
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
# 调用父类(SFTTrainer)的初始化方法,完成基础训练器配置
super().__init__(
model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
peft_config=peft_config,
formatting_func=formatting_func,
)
# 处理教师模型的初始化参数
if args.teacher_model_init_kwargs is None:
teacher_model_init_kwargs = {} # 无参数时使用空字典
elif not isinstance(teacher_model, str):
# 若教师模型已实例化,则不允许传入初始化参数(避免冲突)
raise ValueError(
"已传入实例化的teacher_model,但同时指定了teacher_model_init_kwargs,两者冲突。"
)
else:
teacher_model_init_kwargs = args.teacher_model_init_kwargs
# 处理数据类型参数(将字符串转为torch dtype,如"float16"→torch.float16)
teacher_model_init_kwargs["torch_dtype"] = (
teacher_model_init_kwargs["torch_dtype"]
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
)
# 若教师模型是路径字符串,则加载预训练模型
if isinstance(teacher_model, str):
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
# 禁用学生模型的dropout层(稳定蒸馏过程,减少随机性)
if args.disable_dropout:
disable_dropout_in_model(self.model)
# 准备教师模型:若启用Deepspeed(分布式训练),则适配Deepspeed;否则用accelerator准备
if self.is_deepspeed_enabled:
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
else:
# 将教师模型设为评估模式(不训练,仅用于推理)
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
# 初始化GKD核心超参数
self.lmbda = args.lmbda # 学生自生成样本的概率(on-policy学习概率)
self.beta = args.beta # 广义JSD损失的插值系数
self.temperature = args.temperature # 概率分布的温度系数(控制平滑度)
self.seq_kd = args.seq_kd # 是否强制使用教师生成的序列进行蒸馏
# 初始化生成配置(控制动态样本生成的行为)
self.generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens, # 最大生成token数
temperature=args.temperature, # 生成温度(值越大越随机)
do_sample=True, # 启用采样(而非贪心生成)
top_k=0, # 不限制top_k(配合温度控制多样性)
use_cache=False if args.gradient_checkpointing else True, # 梯度检查点启用时禁用缓存(冲突)
pad_token_id=self.processing_class.pad_token_id, # 填充token ID
)
# 适配模型自定义的EOS token(如Llama 3的<|eot_id|>)
if (
hasattr(self.model.generation_config, "eos_token_id")
and self.model.generation_config.eos_token_id is not None
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
@staticmethod
def generalized_jsd_loss(
student_logits, # 学生模型的logits,形状:(batch_size, seq_len, vocab_size)
teacher_logits, # 教师模型的logits,形状同上
labels=None, # 标签,形状:(batch_size, seq_len),-100表示padding(忽略)
beta=0.5, # 插值系数(控制教师/学生分布权重)
temperature=1.0, # 温度系数(软化概率分布)
reduction="batchmean", # 损失聚合方式(batchmean/sum/mean)
):
"""
计算广义Jensen-Shannon散度(JSD)损失,用于知识蒸馏。
参考论文:https://huggingface.co/papers/2306.13649 公式(1)
"""
# 温度缩放:软化概率分布(温度越高,分布越平滑)
student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
# 计算学生和教师的对数概率(log_softmax)
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
if beta == 0:
# beta=0:退化为传统KL散度(学生模仿教师)
# F.kl_div(input=学生对数概率, target=教师对数概率, log_target=True)
jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
elif beta == 1:
# beta=1:反向KL散度(教师模仿学生,适合学生容量较小时避免模式崩溃)
jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
else:
# 混合分布的对数概率:log[(1-beta)*P_student + beta*P_teacher]
# 等价于log(exp(log(1-beta) + log_P_student) + exp(log(beta) + log_P_teacher))
beta = torch.tensor(beta, dtype=student_log_probs.dtype) # 转为tensor(匹配设备和类型)
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
dim=0, # 按第0维(学生/教师)求和
)
# 计算混合分布与教师/学生分布的KL散度(注意PyTorch的KL顺序与数学定义相反)
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
# 广义JSD:beta*KL(混合||教师) + (1-beta)*KL(混合||学生)
jsd = beta * kl_teacher + (1 - beta) * kl_student
# 掩码处理:忽略padding位置(labels=-100)的损失
if labels is not None:
mask = labels != -100 # 有效位置为True,padding为False
jsd = jsd[mask] # 只保留有效位置的损失
# 损失聚合(根据reduction参数)
if reduction == "batchmean":
# 按有效样本数平均(避免padding影响)
if labels is not None:
return jsd.sum() / mask.sum() # 总损失 / 有效token数
else:
return jsd.sum() / (jsd.size(0) * jsd.size(1)) # 总损失 / 总token数(无标签时)
elif reduction == "sum":
return jsd.sum() # 求和
elif reduction == "mean":
return jsd.mean() # 简单平均
else:
return jsd # 不聚合,返回原始损失 tensor
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
计算GKD的损失:通过广义JSD损失让学生模仿教师的输出分布。
"""
# 学生模型前向传播(获取logits)
outputs_student = model(
input_ids=inputs["input_ids"], # 输入token ID
attention_mask=inputs["attention_mask"], # 注意力掩码(0表示padding)
)
# 教师模型前向传播(评估模式,不计算梯度)
self.teacher_model.eval() # 确保教师模型处于评估模式(禁用dropout等)
with torch.no_grad(): # 禁用梯度计算(节省显存)
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# 切片处理:只保留生成部分的logits(排除输入prompt部分)
prompt_lengths = inputs["prompts"].shape[1] # prompt的长度(输入部分,无需预测)
# 学生logits:从prompt结束位置的前一个token开始,到序列结束前一个token(因语言模型预测下一个token)
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
# 教师logits:同上(与学生对齐)
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
# 标签:从prompt结束位置开始(生成部分的真实标签)
shifted_labels = inputs["labels"][:, prompt_lengths:]
# 计算广义JSD损失
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
labels=shifted_labels, # 用于掩码padding
beta=self.beta, # 从初始化参数获取
)
# 清空GPU缓存(节省显存)
empty_cache()
# 返回损失(可选返回学生模型输出)
return (loss, outputs_student) if return_outputs else loss
@staticmethod
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
"""
生成动态样本(on-policy学习用):基于输入prompt生成输出序列,作为新的训练数据。
"""
# 基于prompt生成输出(仅用prompt作为输入,不包含原始标签)
generated_outputs = model.generate(
input_ids=inputs["prompts"], # 输入prompt的token ID
attention_mask=inputs.get("prompt_attention_mask", None), # prompt的注意力掩码
generation_config=generation_config, # 生成配置(长度、温度等)
return_dict_in_generate=True, # 返回详细生成结果(含序列、分数等)
)
# 获取生成的token ID序列
generated_tokens = generated_outputs.sequences
# 初始化新的注意力掩码(全1,后续修正padding位置)
new_attention_mask = torch.ones_like(generated_tokens)
# 新标签:复制生成的token(后续修正padding位置)
new_labels = generated_tokens.clone()
# 处理padding token(若指定)
if pad_token_id is not None:
# 标签中padding位置设为-100(忽略损失)
new_labels[new_labels == pad_token_id] = -100
# 注意力掩码中padding位置设为0(不参与注意力计算)
new_attention_mask[generated_tokens == pad_token_id] = 0
# 返回生成的输入ID、注意力掩码、标签
return generated_tokens, new_attention_mask, new_labels
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
"""
单步训练:实现on-policy学习,动态生成样本用于训练。
逻辑:以概率lmbda使用学生自生成样本,或强制使用教师生成样本(seq_kd=True)。
"""
if self.seq_kd:
# seq_kd=True:强制使用教师模型生成样本(适合初始训练阶段,学习教师的"正确"输出)
# 解包教师模型(处理PEFT/分布式包装)
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
# 生成教师样本
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
)
# 更新输入为教师生成的样本
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels
# 以概率lmbda使用学生自生成样本(on-policy学习核心)
if random.random() <= self.lmbda:
# 解包学生模型(处理PEFT/分布式包装)
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
# 生成学生样本
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
)
# 更新输入为学生生成的样本(让学生从自身生成的结果中学习)
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels
# 调用父类的training_step计算损失并更新参数
loss = super().training_step(model, inputs, num_items_in_batch)
return loss
def create_model_card(
self,
model_name: Optional[str] = None, # 模型名称
dataset_name: Optional[str] = None, # 训练数据集名称
tags: Union[str, list[str], None] = None, # 模型标签
):
"""
生成模型卡片(README.md),包含训练信息、引用、标签等,方便上传到Hugging Face Hub。
"""
# 仅在主进程执行(避免多进程重复生成)
if not self.is_world_process_zero():
return
# 确定基座模型名称(若模型从预训练模型微调而来)
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# 标准化标签(转为集合避免重复)
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
# 若使用unsloth加速训练,添加对应标签
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
# 添加默认标签(trl和gkd)
tags.update(self._tag_names)
# GKD论文引用格式
citation = textwrap.dedent("""\
@inproceedings{agarwal2024on-policy,
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
year = 2024,
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=3zKtaqxLhW},
}""")
# 生成模型卡片内容
model_card = generate_model_card(
base_model=base_model, # 基座模型
model_name=model_name, # 模型名称
hub_model_id=self.hub_model_id, # Hub上的模型ID
dataset_name=dataset_name, # 训练数据集
tags=tags, # 标签
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, # wandb实验URL
comet_url=get_comet_experiment_url(), # Comet实验URL
trainer_name="GKD", # 训练器名称
trainer_citation=citation, # 引用
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", # 论文标题
paper_id="2306.13649", # 论文ID(arXiv或OpenReview)
)
# 保存模型卡片到输出目录
model_card.save(os.path.join(self.args.output_dir, "README.md"))
GKDConfig
from dataclasses import dataclass, field
from typing import Any, Optional
from transformers import TrainingArguments # 导入Hugging Face的训练参数基类
from .sft_config import SFTConfig # 导入监督微调(SFT)的配置类(GKDConfig的父类)
@dataclass
class GKDConfig(SFTConfig):
"""
广义知识蒸馏(Generalized Knowledge Distillation, GKD)训练器的配置类。
此类仅包含GKD训练特有的参数,完整的训练参数请参考`transformers.TrainingArguments`和`SFTConfig`的文档。
参数说明:
temperature (`float`, 可选, 默认值 `0.9`):
采样温度。温度越高,生成的结果随机性越强。
lmbda (`float`, 可选, 默认值 `0.5`):
控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。
beta (`float`, 可选, 默认值 `0.5`):
广义Jensen-Shannon散度(JSD)损失的插值系数,范围在`0.0`到`1.0`之间。
当beta=0.0时,损失退化为传统KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。
max_new_tokens (`int`, 可选, 默认值 `128`):
每次生成的最大token数量。
teacher_model_name_or_path (`str` 或 `None`, 可选, 默认值 `None`):
教师模型的名称或路径。若为`None`,则教师模型与当前训练的模型相同。
teacher_model_init_kwargs (`dict[str, Any]` 或 `None`, 可选, 默认值 `None`):
从字符串实例化教师模型时,传递给`AutoModelForCausalLM.from_pretrained`的关键字参数。
disable_dropout (`bool`, 可选, 默认值 `True`):
是否禁用模型中的dropout层(蒸馏中常用,以减少随机性,稳定训练)。
seq_kd (`bool`, 可选, 默认值 `False`):
是否执行序列级蒸馏(Sequence-Level KD),可视为在教师生成的输出上进行监督微调。
"""
# 扩展有效字典字段:在TrainingArguments的基础上添加教师模型的初始化参数
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
temperature: float = field(
default=0.9,
metadata={"help": "采样温度。温度越高,生成的结果随机性越强。"},
)
lmbda: float = field(
default=0.5,
metadata={
"help": "控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。"
},
)
beta: float = field(
default=0.5,
metadata={
"help": "广义Jensen-Shannon散度(JSD)损失的插值系数,范围在0.0到1.0之间。"
"当beta=0.0时,损失为KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。"
},
)
max_new_tokens: int = field(
default=128,
metadata={"help": "每次生成的最大token数量。"},
)
teacher_model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "教师模型的名称或路径。若为None,教师模型将与当前训练的模型相同。"
},
)
teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={
"help": "从字符串实例化教师模型时,传递给AutoModelForCausalLM.from_pretrained的关键字参数。"
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "是否禁用模型中的dropout层(蒸馏中常用以稳定训练)。"},
)
seq_kd: bool = field(
default=False,
metadata={
"help": "是否执行序列级蒸馏(可视为在教师生成的输出上进行监督微调)。"
},
)
def __post_init__(self):
"""初始化后执行的方法:调用父类初始化逻辑,并验证参数合法性。"""
super().__post_init__() # 调用父类(SFTConfig)的初始化后处理逻辑
# 验证lmbda参数是否在[0, 1]范围内
if self.lmbda < 0.0 or self.lmbda > 1.0:
raise ValueError("lmbda参数必须在[0.0, 1.0]范围内。")
# 验证beta参数是否在[0, 1]范围内
if self.beta < 0.0 or self.beta > 1.0:
raise ValueError("beta参数必须在[0.0, 1.0]范围内。")