Chapter5.4 Loading and saving model weights in PyTorch

发布于:2025-02-10 ⋅ 阅读:(45) ⋅ 点赞:(0)

5 Pretraining on Unlabeled Data

5.4 Loading and saving model weights in PyTorch

  • 训练LLM的计算成本很高,因此能够保存和加载LLM的权重至关重要。

  • 在PyTorch中,推荐的方式是通过将torch.save函数应用于.state_dict()方法来保存模型权重,即所谓的state_dict

    torch.save(model.state_dict(),"model.pth")
    

    我们可以将模型权重加载到新的 GPTModel 模型实例中

    model = GPTModel(GPT_CONFIG_124M)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.eval();
    
  • 自适应优化器(如 AdamW)为每个模型权重存储额外的参数。AdamW 使用历史数据动态调整每个模型参数的学习率。如果没有这些参数,优化器会重置,模型可能会学习效果不佳,甚至无法正确收敛,这意味着模型将失去生成连贯文本的能力。使用 torch.save,我们可以保存模型和优化器的 state_dict 内容,如下所示

    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        }, 
        "model_and_optimizer.pth"
    )
    

    然后,我们可以通过以下方式恢复模型和优化器状态:首先通过 torch.load 加载保存的数据,然后使用 load_state_dict 方法:

    checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)
    
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint["model_state_dict"])
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train();
    


网站公告

今日签到

点亮在社区的每一天
去签到