深度学习篇---模型参数保存

发布于:2025-08-29 ⋅ 阅读:(18) ⋅ 点赞:(0)

在深度学习模型训练和部署过程中,模型保存是一个关键环节。不同框架在模型保存的实现上既有相似之处,也有各自的特点。下面详细介绍 PyTorch、TensorFlow 和 PaddlePaddle 中模型保存的代码及保存内容:

1. PyTorch

PyTorch 提供了灵活的模型保存方式,主要通过torch.save()函数实现,可保存模型结构、参数或训练状态。

(1)保存模型参数(推荐)

仅保存模型的参数(权重和偏置),不包含模型结构,文件体积较小。

import torch
import torch.nn as nn

# 定义示例模型
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 保存模型参数(状态字典,state_dict)
torch.save(model.state_dict(), "model_params.pth")
  • 保存内容:模型的state_dict,是一个字典,层名称对应参数的张量
  • 用途:适用于训练中断后恢复训练,或在已知模型结构的情况下加载参数。
(2)保存完整模型

保存整个模型(包括结构和参数),但可能存在兼容性问题(如不同 PyTorch 版本或 Python 环境)。

# 保存完整模型
torch.save(model, "full_model.pth")
  • 保存内容:模型的类结构、参数及其他属性(如训练配置)。
  • 注意:不推荐用于跨环境部署,可能因类定义变化导致加载失败。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、epoch 等信息,用于中断后继续训练。

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.123

# 保存训练状态
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
    "loss": loss
}
torch.save(checkpoint, "checkpoint.pth")
  • 保存内容:模型参数、优化器参数(如动量、学习率)、当前训练轮次、损失值等。

2. TensorFlow(Keras)

TensorFlow(尤其是 Keras 接口)提供了多种模型保存方式,支持 SavedModel 格式(推荐)和 HDF5 格式。

(1)保存完整模型(SavedModel 格式,推荐)

SavedModel 是 TensorFlow 的标准格式,包含模型结构、参数、计算图等,兼容性强。

  • 保存内容
    • 模型结构(网络层、输入输出形状);
    • 所有参数(权重和偏置);
    • 训练配置(优化器、损失函数、 metrics);
    • 计算图(用于部署到 TensorFlow Serving、移动端等)。
  • 用途:模型部署、跨平台使用(如 TensorFlow Lite、TensorRT)。
(2)保存为 HDF5 格式

保存模型结构和参数到单一文件,适用于简单场景。

# 保存为HDF5格式
model.save("model.h5")
  • 保存内容:模型结构(JSON 格式)和参数(二进制),但不包含计算图细节。
  • 注意:对复杂模型(如自定义层、控制流)的兼容性较差。
(3)保存权重(仅参数)

仅保存模型参数,需已知模型结构才能加载。

# 保存权重
model.save_weights("model_weights.h5")
  • 保存内容:各层的权重张量,不包含模型结构。
(4)训练过程保存(Checkpoint)

通过ModelCheckpoint回调保存训练过程中的模型状态。

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="training_checkpoint",
    save_weights_only=False,  # 是否仅保存权重
    save_best_only=True,      # 仅保存性能最好的模型
    monitor="val_loss"        # 监控指标
)

# 训练时使用回调
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])
  • 保存内容:根据配置,可保存完整模型或仅权重,支持按指标(如验证集损失)保存最优模型。

3. PaddlePaddle

PaddlePaddle 的模型保存逻辑与 PyTorch 类似,主要通过paddle.save()Model.save()实现。

(1)保存模型参数(推荐)

仅保存模型参数,需结合模型结构加载。

import paddle
from paddle.nn import Linear

# 定义示例模型
class SimpleModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = Linear(in_features=10, out_features=2)
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 保存模型参数
paddle.save(model.state_dict(), "model_params.pdparams")
  • 保存内容:模型的state_dict,键为层名称,值为参数张量。
(2)保存完整模型

保存模型结构和参数,方便直接加载使用。

# 保存完整模型
paddle.Model(model).save("full_model")
  • 保存内容:模型结构(__model__文件)和参数(*.pdparams),支持跨环境加载。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、训练轮次等。

optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001)
epoch = 10
loss = 0.123

# 保存训练状态
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
    "loss": loss
}
paddle.save(checkpoint, "checkpoint.pdparams")
  • 保存内容:模型参数、优化器参数(如学习率、动量)、训练进度等。

总结

框架 保存类型 核心函数 / 方法 主要保存内容
PyTorch 仅参数 torch.save(model.state_dict(), ...) 模型参数(state_dict)
完整模型 torch.save(model, ...) 模型结构 + 参数
训练状态(断点续训) torch.save(checkpoint_dict, ...) 模型参数 + 优化器状态 + 训练进度
TensorFlow 完整模型(推荐) model.save("saved_model") 结构 + 参数 + 计算图 + 训练配置
HDF5 格式 model.save("model.h5") 结构 + 参数(兼容性有限)
仅参数 model.save_weights(...) 各层权重
训练过程检查点 ModelCheckpoint回调 按配置保存模型或权重(支持最优模型选择)
PaddlePaddle 仅参数 paddle.save(model.state_dict(), ...) 模型参数(state_dict)
完整模型 paddle.Model(model).save(...) 结构 + 参数
训练状态(断点续训) paddle.save(checkpoint_dict, ...) 模型参数 + 优化器状态 + 训练进度

实际应用中,仅保存参数通常是最灵活和高效的方式(需配合模型结构加载);完整模型适合快速部署但需注意兼容性;训练状态保存则用于中断后恢复训练。


网站公告

今日签到

点亮在社区的每一天
去签到