Pytorch混合精度训练最佳实践

发布于:2025-07-28 ⋅ 阅读:(22) ⋅ 点赞:(0)

混合精度训练(Mixed Precision Training)是一种通过结合单精度(FP32)和半精度(FP16/FP8)计算来加速训练、减少显存占用的技术。它在保持模型精度的同时,通常能带来 2-3 倍的训练速度提升,并减少约 50% 的显存使用,是平衡训练效率与数值稳定性的核心技术,尤其在大模型训练中不可或缺。
以下从GradScaler 底层逻辑、避坑技巧(含 NaN 解决方案)、PyTorch Lightning 实战三个维度展开,结合实战细节提供更深入的最佳实践。

一、核心思想

混合精度训练的核心原理
传统训练使用 32 位浮点数(FP32)存储参数和计算梯度,但研究发现:

  • 模型参数和激活值对精度要求较高(需 FP32)
  • 梯度计算和反向传播对精度要求较低(可用 FP16)

混合精度训练的核心逻辑:

  • 用 FP16 执行大部分计算(前向 / 反向传播),加速运算并减少显存
  • 用 FP32 保存模型参数和优化器状态,确保数值稳定性
  • 通过 “损失缩放”(Loss Scaling)解决 FP16 梯度下溢问题

二、GradScaler 缩放机制:从原理到细节

GradScaler 的核心是解决 FP16 梯度下溢问题,但缩放逻辑并非简单的固定倍数,而是动态自适应的。理解其底层逻辑能更好地控制训练稳定性。

  1. 核心流程

    • 缩放损失:将损失乘以一个缩放因子(初始值通常为 2^16),使梯度按比例放大,避免下溢;
    • 反向传播:用缩放后的损失计算梯度(FP16);
    • 梯度修正:将梯度除以缩放因子,恢复真实梯度值;
    • 参数更新:用修正后的梯度更新参数(FP32)。
  2. 缩放因子的动态调整逻辑
    GradScaler 维护一个初始缩放因子(默认2^16),并根据每次迭代的梯度是否溢出动态调整:

    • 无溢出:若连续多次(默认2000次)未溢出,缩放因子乘以growth_factor(默认1.0001),逐步放大以更充分利用 FP16 范围;
    • 有溢出:缩放因子乘以backoff_factor(默认0.5),并跳过本次参数更新(避免错误梯度影响模型)
    # 查看当前缩放因子(调试用)
    print(scaler.get_scale())  # 输出当前缩放因子值
    
  3. PyTorch 原生混合精度训练实现

import torch
import torch.nn as nn
from torch.optim import Adam

# 1. 初始化模型、损失函数、优化器
model = nn.Linear(10, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# 2. 初始化GradScaler(关键)
scaler = torch.cuda.amp.GradScaler()

# 3. 训练循环
for inputs, labels in train_loader:
    inputs, labels = inputs.cuda(), labels.cuda()
    
    # 清零梯度
    optimizer.zero_grad()
    
    # 前向传播:用FP16计算(torch.cuda.amp.autocast上下文)
    with torch.cuda.amp.autocast():  # 自动将FP32操作转为FP16(支持的操作)
        outputs = model(inputs)
        loss = criterion(outputs, labels)  # 损失仍以FP32计算(更稳定)
    
    # 反向传播:缩放损失,计算梯度
    scaler.scale(loss).backward()  # 缩放损失,避免梯度下溢
    
    # 梯度裁剪(可选,防止梯度爆炸)
    scaler.unscale_(optimizer)  # 先将梯度恢复(除以缩放因子)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 裁剪梯度
    
    # 更新参数:用修正后的梯度
    scaler.step(optimizer)  # 仅当梯度未溢出时更新(内部会判断)
    
    # 更新缩放因子(动态调整)
    scaler.update()  # 若本次无溢出,增大因子;若溢出,减小因子并跳过本次更新

关键细节:

  • torch.cuda.amp.autocast:自动将支持 FP16 的操作转为 FP16(如卷积、线性层),不支持的仍用 FP32(如 softmax);
    • 在 torch.cuda.amp.autocast 上下文(混合精度训练)中,损失函数默认以 FP32 计算以提高数值稳定性:
      • 损失值通常较小:模型输出的损失(如交叉熵、MSE)数值范围往往较小(例如 1e-3 ~ 10),FP16 的小数精度有限(仅能精确表示 1e-4 以上的数值),若用 FP32 计算可避免精度损失。
      • 梯度计算的起点:损失是反向传播的起点,其精度直接影响梯度的准确性。FP32 损失能提供更稳定的梯度初始值,减少后续梯度下溢 / 爆炸的风险。
    • 实际上,autocast 上下文会自动判断操作是否适合 FP16:
      • 对于计算密集型操作(如卷积、线性层),自动转为 FP16 以提升速度;
      • 对于精度敏感操作(如损失计算、softmax、 BatchNorm),默认保留 FP32 或自动回退到 FP32。
  • scaler.step(optimizer):内部会检查梯度是否溢出(若梯度为 inf/NaN 则跳过更新);
  • 保存 / 加载 checkpoint 时,需同步保存scaler.state_dict(),否则缩放因子状态丢失会导致训练不稳定。

三、避免 Loss 为 NaN 的实践方法

Loss 为 NaN 通常源于数值不稳定(梯度爆炸 / 下溢、极端值计算),可从以下方面解决:

  1. 控制梯度范围

    • 梯度裁剪:如上述代码中clip_grad_norm_,限制梯度 L2 范数(建议 max_norm=1.0~10.0);
    • 调整 GradScaler 参数:通过growth_factor(默认 1.0001,增大因子的速度)和backoff_factor(默认 0.5,减小因子的速度)控制缩放因子,避免过度放大导致梯度爆炸。
     # 更保守的scaler配置(适合易溢出场景)
     scaler = torch.cuda.amp.GradScaler(
         init_scale=2.**10,  # 初始缩放因子(默认2^16,减小初始值更保守)
         growth_factor=1.1,
         backoff_factor=0.8
     )
    
  2. 检查数据与标签

    • 确保输入数据无异常值(如 inf/NaN),可通过torch.isnan(inputs).any()检查;
    • 标签需在有效范围内(如分类任务标签不超过类别数)。
  3. 调整关键计算为 FP32

    • 部分操作在 FP16 下易不稳定(如 softmax、交叉熵、BatchNorm),可强制用 FP32 计算:
     with torch.cuda.amp.autocast():
         outputs = model(inputs)
         # 强制损失计算用FP32
         loss = criterion(outputs.float(), labels)  # 将outputs转为FP32
    
  4. 降低学习率

    • 过大的学习率可能导致参数更新幅度过大,引发数值爆炸。建议初始学习率比 FP32 训练时小(如缩小 10 倍),再逐步调整。
  5. 禁用FP16的 BatchNorm/LayerNorm/RevIN等

    • BatchNorm 在 FP16 下可能因均值 / 方差计算精度不足导致不稳定,可强制用 FP32:
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.float()  # 将BatchNorm参数转为FP32
    

四、PyTorch Lightning 实现混合精度训练

PyTorch Lightning 通过封装torch.cuda.amp,大幅简化混合精度训练流程,只需在Trainer中设置precision参数。

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __len__(self): return 1000
    def __getitem__(self, idx): return torch.randn(10), torch.randint(0, 2, (1,)).item()

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(10, 2)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

# 训练器配置(关键:设置precision)
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    precision=16,  # 16: FP16混合精度;"bf16": BF16混合精度;32: 纯FP32
    gradient_clip_val=1.0  # 梯度裁剪(防NaN)
)

# 启动训练
model = LitModel()
train_loader = DataLoader(MyDataset(), batch_size=32)
trainer.fit(model, train_loader)

高级配置

  • 自定义 GradScaler:若需调整 scaler 参数,可重写configure_gradient_clipping:
    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
      scaler = self.trainer.scaler  # 获取Lightning内部的scaler
      scaler._init_scale = 2.** 10  # 调整初始缩放因子
      super().configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)
    

五、模型层面避坑点

在混合精度训练中,模型层面需要针对性修改的模块主要是那些对数值精度敏感、容易在 FP16(半精度)下产生不稳定的组件。这些模块的特性(如动态范围大、依赖精确统计量计算)使其在低精度下容易出现梯度下溢、数值爆炸或精度损失,需要特殊处理。以下是需要重点关注和修改的模块及具体策略:

  1. 归一化层(BatchNorm/GroupNorm/LayerNorm)
    归一化层依赖均值(mean)和方差(variance)的统计量计算,这些值通常较小(接近 0),在 FP16 下可能因精度不足导致数值不稳定(如方差变为 0,引发除以 0 错误)。此外,归一化层的缩放参数(gamma/beta)若动态范围大,也容易在 FP16 下溢出。
    修改策略:
  • 强制使用 FP32 参数和统计量:
    • 将归一化层的权重(weight)、偏置(bias)及运行统计量(running_mean/running_var)保留在 FP32 精度,仅计算过程中的中间结果可使用 FP16。
      for m in model.modules():
          if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm, nn.LayerNorm)):
              m.float()  # 强制参数为FP32
              # 对于LayerNorm,可进一步固定eps避免过小值
              if isinstance(m, nn.LayerNorm):
                  m.eps = 1e-5  # 增大eps,避免除以接近0的方差
      
  1. 激活函数(Softmax/LogSoftmax/Sigmoid 等)
  • Softmax/LogSoftmax 数值稳定化:
    减去输入的最大值(不改变结果),避免指数运算溢出:
    def safe_softmax(x, dim=-1):
        x = x - x.max(dim=dim, keepdim=True).values  # 数值稳定技巧
        return torch.softmax(x, dim=dim)
    
    # 替换模型中的原生Softmax
    model.classifier.softmax = safe_softmax  # 假设模型最后一层有softmax
    
  • 激活函数输入截断:
    对输入范围进行限制,避免极端值进入激活函数:
    class SafeSigmoid(nn.Module):
        def forward(self, x):
            x = torch.clamp(x, min=-10, max=10)  # 截断到[-10,10],避免梯度下溢
            return torch.sigmoid(x)
    
  • 优先使用 FP32 计算激活:
    对敏感激活函数,强制在 FP32 下计算:
    with torch.cuda.amp.autocast():
        x = model.conv(x)  # FP16计算
        x = x.float()  # 转为FP32
        x = torch.softmax(x, dim=-1)  # FP32下计算softmax
    
  1. 损失函数(CrossEntropy/MSE 等)
    • 损失计算强制 FP32:
      将模型输出转为 FP32 后再计算损失,避免低精度导致的不稳定:
    with torch.cuda.amp.autocast():
        logits = model(inputs)  # FP16输出
        # 转为FP32计算损失
        loss = criterion(logits.float(), labels)
    
  • 损失值截断或缩放:
    对回归任务的 MSE 损失,可先缩放目标值到合理范围(如[-1,1]),或对损失值进行截断:
    def safe_mse_loss(pred, target):
        pred = torch.clamp(pred, min=-1000, max=1000)  # 截断预测值
        target = torch.clamp(target, min=-1000, max=1000)
        return torch.nn.functional.mse_loss(pred, target)
    
  1. 优化器
    • 优化器参数保留 FP32:
      PyTorch 默认会将优化器状态(如 Adam 的 m 和 v)存储为与参数相同的精度,需强制用 FP32:
    # 初始化优化器时指定参数为FP32
    optimizer = Adam([p.float() for p in model.parameters()], lr=1e-3)
    
  2. 模块修改优先级
    • 最高优先级:BatchNorm/GroupNorm(统计量敏感)、Softmax/LogSoftmax(易溢出);
    • 高优先级:Transformer 注意力机制(分数计算范围大)、交叉熵损失(log 运算敏感);
    • 中等优先级:Sigmoid 等激活函数(梯度易下溢)、优化器状态(需 FP32 累积);
    • 低优先级:卷积层 / 线性层(PyTorch 对其 FP16 支持较好,通常无需修改)。

通过针对性修改这些模块,可在混合精度训练中显著提升数值稳定性,避免因精度问题导致的 Loss 为 NaN 或模型收敛异常。实际应用中,建议结合训练日志(如监控缩放因子变化、梯度范围)逐步调整,找到最适合模型的配置。

六、PyTorch Lightning 中手动指定模块用 FP32 的方法

PyTorch Lightning(PL)虽然自动封装了混合精度逻辑,但仍支持手动控制特定模块的精度(如强制某层用 FP32)。核心思路是在模块的前向传播中显式转换数据类型,或局部禁用 autocast。

  1. 方法一:在模型层中显式转换为 FP32
    适用于需要对特定模块(如 BatchNorm、损失计算)强制 FP32 的场景,直接在模块的 forward 方法中转换输入 / 参数类型:
import pytorch_lightning as pl
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)  # 常规卷积层(可FP16)
        self.bn = nn.BatchNorm2d(64)     # 对精度敏感的BN层(强制FP32)
        self.fc = nn.Linear(64, 10)      # 输出层(可FP16)

    def forward(self, x):
        # 1. 卷积层:默认FP16(PL的autocast生效)
        x = self.conv(x)
        
        # 2. BN层:强制FP32计算
        x = x.float()  # 输入转为FP32
        x = self.bn(x)  # BN层参数已在初始化时设为FP32(见下方)
        x = x.half()   # 转回FP16,不影响后续层
        
        # 3. 全连接层:默认FP16
        x = self.fc(x)
        return x

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()
        # 强制BN层参数为FP32(关键)
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.float()  # 权重和统计量用FP32
        
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        
        # 2. 损失计算:强制FP32(类似原生PyTorch)
        loss = self.criterion(logits.float(), y)  # logits转为FP32
        
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
  1. 方法二:局部禁用 autocast 上下文
    若需要某段代码完全禁用混合精度(强制 FP32),可使用 torch.cuda.amp.autocast(enabled=False) 覆盖 PL 的全局设置:
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # PL会自动开启autocast,但可局部禁用
        with torch.cuda.amp.autocast(enabled=False):  # 此范围内强制FP32
            # 例如:对敏感操作(如注意力分数计算)强制FP32
            logits = self.model(x.float())  # 输入转为FP32
            loss = self.criterion(logits, y)
        
        self.log("train_loss", loss)
        return loss
  1. 方法三:通过 PrecisionPlugin 自定义精度策略
    对于更复杂的场景(如部分模块用 FP16、部分用 FP32),可自定义 PrecisionPlugin 控制全局精度逻辑:
from pytorch_lightning.plugins import MixedPrecisionPlugin

# 自定义混合精度插件
class CustomMixedPrecisionPlugin(MixedPrecisionPlugin):
    def __init__(self):
        super().__init__(precision=16, scaler=torch.cuda.amp.GradScaler())

    def autocast_context_manager(self):
        # 全局默认开启autocast,但可在模型中局部禁用
        return torch.cuda.amp.autocast(dtype=torch.float16)

# 在Trainer中使用自定义插件
trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    plugins=[CustomMixedPrecisionPlugin()],
    max_epochs=10
)
  1. 总结
    在大多数情况下,torch.cuda.amp.autocast 能够自动判断操作是否适合 FP16,无需手动干预(如强制类型转换)。PyTorch 对 autocast 的设计目标就是 “开箱即用”,通过内置的操作映射规则,自动为不同操作选择最优精度。
    autocast 内部维护了一份操作 - 精度映射表,对 PyTorch 原生算子(如卷积、线性层、激活函数等)做了精细化适配:
    • 优先 FP16 的操作:计算密集型且对精度不敏感的操作,如 nn.Conv2d、nn.Linear、nn.ReLU 等,这些操作在 FP16 下速度提升明显,且精度损失可接受。
    • 自动回退 FP32 的操作:精度敏感或易溢出的操作,如:
      • 损失函数(nn.CrossEntropyLoss、nn.MSELoss 等);
      • 归一化层(nn.BatchNorm、nn.LayerNorm 的内部统计量计算);
      • 数值不稳定的操作(torch.softmax、torch.log_softmax 等)。
      • 例如,当你在 autocast 上下文内调用 loss = criterion(outputs, labels) 时,即使 outputs 是 FP16,PyTorch 也会自动将其转换为 FP32 进行损失计算,再将损失以 FP32 输出 —— 整个过程无需手动干预。

尽管 autocast 设计得很智能,但在自定义操作或复杂模型结构中,可能出现自动适配不符合预期的情况,此时需要手动控制精度。常见场景包括:

  • 自定义算子或未被 autocast 覆盖的操作
    如果模型中包含 PyTorch 原生算子之外的自定义操作(如自研 CUDA 算子、特殊数学运算),autocast 可能无法识别,导致其默认使用 FP32(影响速度)或错误使用 FP16(导致精度问题)。比如在时间序列模型经常使用RevIN模块,该模块为了解决时序的分布漂移,如果不手动干预,torch.amp.autocast是不能很好处理的.
  • 模型中间层出现数值异常(如溢出 / 下溢)
    即使使用原生算子,某些特殊场景(如输入值范围极端、模型深度过深)可能导致中间层数值异常(如 FP16 下卷积输出突然变为 inf)。此时需对异常层手动干预。
  • 对精度有极致要求的场景
    某些任务(如医疗影像、高精度回归)对数值精度要求极高,即使 autocast 自动适配,也可能需要关键模块强制 FP32 以减少精度损失。

简言之,autocast 的自动适配是 “最优解”,手动干预仅作为 “异常修复手段”。实际开发中,建议先依赖自动适配,遇到问题再针对性调整。

七、Debug问题定位

在 PyTorch Lightning(PL)混合精度训练中出现 Loss 为 NaN,通常是数值不稳定累积的结果(而非突然出现)。定位问题需要从数据→模型→训练机制→混合精度配置逐步排查,结合 PL 的调试工具可高效定位根因。以下是具体的 debug 流程和操作方法:

7.1 复现与简化:缩小问题范围

首先通过简化实验快速复现问题,排除偶然因素:

  1. 缩短训练流程:用fast_dev_run=True让模型快速跑 1 个 batch 的训练 + 验证,观察是否立即出现 NaN(排除多 epoch 累积效应)。
trainer = pl.Trainer(
    fast_dev_run=True,  # 快速验证流程
    precision=16,
    accelerator="gpu"
)
  1. 固定随机种子:确保结果可复现,排除数据随机波动导致的偶然 NaN:
pl.seed_everything(42, workers=True)  # 固定所有随机种子
  1. 减少 batch size:若大 batch 下出现 NaN,尝试batch_size=1,判断是否与数据分布不均相关。

7.2 模型层面:定位敏感模块的数值不稳定

混合精度下,模型中对精度敏感的模块(如归一化层、激活函数、注意力分数)易成为 NaN 源头。可通过模块输出监控和精度隔离测试定位问题。

  1. 监控模块输出范围
    在模型的关键模块后添加输出范围监控,追踪数值是否异常膨胀:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.conv(x)
        # 监控卷积层输出(易数值膨胀)
        if torch.isnan(x).any():
            print("Conv层输出NaN!")
        self.log("conv_max", x.max().item())  # 记录到PL日志
        
        x = self.bn(x)
        # 监控BN层输出(方差为0会导致NaN)
        self.log("bn_max", x.max().item())
        
        x = self.relu(x)
        self.log("relu_max", x.max().item())
        return x

重点关注:

  • 卷积 / 线性层输出是否突然增大(如超过1e4);
  • BN 层输出是否出现inf(可能因方差为 0 导致);
  • 激活函数(如 ReLU)后是否仍有极端值(说明激活未起到截断作用)。
  1. 隔离测试:逐步禁用模块的 FP16
    若怀疑某模块在 FP16 下不稳定,可强制其用 FP32 计算,观察是否消除 NaN(即 “精度隔离测试”):

7.4 混合精度机制:检查 GradScaler 与精度切换

PL 的混合精度依赖GradScaler和autocast,若这两者配置不当,会直接导致梯度溢出或参数更新异常。

  1. 监控 GradScaler 的缩放因子
    缩放因子(scale)的异常变化是数值不稳定的重要信号:
    • 若缩放因子突然从1e4暴跌到1e-2,说明梯度频繁溢出(Scaler 被迫减小因子);
    • 若缩放因子持续为初始值(如2^16)且无增长,可能存在梯度下溢。
      在 PL 中可通过Trainer的log_every_n_steps监控缩放因子:
    # 在LightningModule中添加scaler监控
    def training_step(self, batch, batch_idx):
        # ... 前向传播 ...
        self.log("scaler_scale", self.trainer.scaler.get_scale(), prog_bar=True)
        return loss
    
    应对措施:若缩放因子暴跌,说明当前学习率可能过大,可尝试降低学习率(如缩小为原来的 1/10)。
  2. 检查梯度是否溢出
    PL 的Trainer可启用梯度异常检测,定位梯度溢出的具体参数:
    trainer = pl.Trainer(
        precision=16,
        detect_anomaly=True,  # 启用梯度异常检测(会降低速度,仅调试用)
        accelerator="gpu"
    )
    
    启用后,若梯度出现inf/NaN,会打印具体的参数名称(如conv.weight),直接定位到异常模块。
  3. 验证是否是混合精度本身的问题
    对比纯 FP32 训练结果,判断是否由混合精度机制导致:
    • 若纯 FP32 训练无 NaN,说明问题与混合精度的数值敏感相关;
    • 若纯 FP32 仍有 NaN,说明问题在模型或数据本身(如学习率过大、数据异常)。

7.5 训练配置:排查学习率与梯度裁剪

混合精度下,梯度经过缩放后,实际有效学习率可能被放大,导致参数更新幅度过大,最终引发 NaN。

  1. 降低学习率并监控参数更新幅度
    混合精度训练的初始学习率建议为纯 FP32 的 1/2~1/10(因梯度缩放可能等效放大学习率)。可在configure_optimizers中临时降低学习率:
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-4)  # 从1e-3降至1e-4
    return optimizer

同时监控参数更新幅度(更新前后的 L2 距离):

def training_step(self, batch, batch_idx):
    # ... 前向传播 ...
    loss.backward()  # 手动触发反向传播(便于调试)
    if batch_idx % 10 == 0:  # 每10个batch检查一次
        for name, param in self.named_parameters():
            if param.grad is not None:
                update_norm = (param.grad * self.optimizers().param_groups[0]['lr']).norm()
                self.log(f"update_{name}", update_norm, prog_bar=True)
    return loss

若更新幅度超过1e2,说明学习率可能过大。
2. 检查梯度裁剪是否生效
PL 的gradient_clip_val参数若配置不当,可能导致梯度未被有效裁剪:

trainer = pl.Trainer(
    precision=16,
    gradient_clip_val=1.0,  # 裁剪梯度L2范数至1.0
    gradient_clip_algorithm="norm",  # 推荐用L2范数裁剪
)

可在training_step中验证裁剪效果:

def training_step(self, batch, batch_idx):
    # ... 前向传播与反向传播 ...
    self.manual_backward(loss)  # 手动反向传播(PL默认自动处理,显式写出方便调试)
    # 裁剪前检查梯度范数
    grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0, norm_type=2)
    self.log("grad_norm_before_clip", grad_norm, prog_bar=True)
    self.optimizers().step()
    self.optimizers().zero_grad()

若grad_norm_before_clip持续超过1.0,说明裁剪未生效(可能是 PL 版本 bug,可尝试手动裁剪)。


网站公告

今日签到

点亮在社区的每一天
去签到