PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环

发布于:2025-07-23 ⋅ 阅读:(18) ⋅ 点赞:(0)

如果你在 PyTorch 中只做「调包侠」,那么永远只是在外围打转;只有把「模型定义 → 修改 → 保存/加载」整条链路打通,才算真正拥有了炼丹炉的钥匙。
本文把官方教程 5.1–5.4 浓缩成一篇逻辑闭环的实战笔记,力求“看完即可落地”。


1. 为什么要有“模型工程化”思维?

阶段 痛点举例 本章解法
快速验证 一行行手写 100 层 CNN? Sequential / 模型块
需求变更 ResNet50 输出从 1000 → 10 类 局部层替换 / 外部输入输出
训练中断 断电后需从头再来 断点续训
部署迁移 8 卡训练 → 1 卡推理报错 统一权重前缀

2. 模型定义:三种姿势,按需选择

2.1 Sequential —— 极简线性堆叠

net = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
)

适用:快速 PoC、网络无分支。

2.2 ModuleList / ModuleDict —— 乐高式复用

class TinyResNet(nn.Module):
    def __init__(self, n_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            Bottleneck(64, 64) for _ in range(n_blocks)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

适用:重复单元、需要动态深度。


3. 模型修改:三大高频需求一次讲透

torchvision.models.resnet50() 为例。

需求 关键 API / 技巧 代码片段
改输出类别 直接替换 fc net.fc = nn.Linear(2048, 10)
加额外输入 forward 里 torch.cat x = torch.cat([net(x), add_var.unsqueeze(1)], 1)
多输出/中间特征 修改 forward 的 return return out, feature

所有修改只需继承 nn.Module 并重写 __init__forward,无需动原始源码。


4. 模型保存与加载:单卡/多卡一次说清

4.1 存什么?

方式 命令 优缺点
仅权重 torch.save(model.state_dict(), path) 轻量、跨环境兼容
整个模型 torch.save(model, path) 含结构,但依赖原始类定义和 Python 版本

实战建议:99% 场景只存权重。

4.2 单卡 ↔ 多卡权重前缀问题

  • 多卡训练会引入 "module." 前缀
  • 通用解法:存权重时统一存 model.module.state_dict(),或加载时 strip 前缀:
state = torch.load('multi_gpu.pth')
new_state = {k[7:]: v for k, v in state.items()}  # 去掉 'module.'
model.load_state_dict(new_state)

4.3 断点续训:把训练状态一起打包

torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'epoch': epoch,
    'best_acc': best_acc
}, 'checkpoint.pth')

# 恢复
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1

5. 一条完整的开发流水线示例

Sequential / ModuleList
改层
加输入
加输出
单卡
多卡
定义网络
训练
需求变更
局部替换 fc
重写 forward + cat
return 多个值
保存权重 state_dict
部署环境
直接 load
DataParallel + strip 前缀
继续训练 / 推理

6. 小结 & 行动清单

任务场景 立即能做的最小行动
快速搭 baseline nn.Sequential 10 行内出模型
迁移学习 把 ResNet50 的 fc 替换成你的类别数
断电续训练 把 optimizer & epoch 一起写进 checkpoint
8 卡训练 → 单卡推理 保存 model.module.state_dict()

记住一句话:权重是模型的灵魂,结构是容器;容器可以重建,灵魂必须妥善保存。


参考资料
《深入浅出PyTorch》第5章 5.1–5.4(DatawhaleChina 团队)
官方文档:torch.save / torch.load / nn.DataParallel


网站公告

今日签到

点亮在社区的每一天
去签到