如果你在 PyTorch 中只做「调包侠」,那么永远只是在外围打转;只有把「模型定义 → 修改 → 保存/加载」整条链路打通,才算真正拥有了炼丹炉的钥匙。
本文把官方教程 5.1–5.4 浓缩成一篇逻辑闭环的实战笔记,力求“看完即可落地”。
1. 为什么要有“模型工程化”思维?
阶段 | 痛点举例 | 本章解法 |
---|---|---|
快速验证 | 一行行手写 100 层 CNN? | Sequential / 模型块 |
需求变更 | ResNet50 输出从 1000 → 10 类 | 局部层替换 / 外部输入输出 |
训练中断 | 断电后需从头再来 | 断点续训 |
部署迁移 | 8 卡训练 → 1 卡推理报错 | 统一权重前缀 |
2. 模型定义:三种姿势,按需选择
2.1 Sequential —— 极简线性堆叠
net = nn.Sequential(
nn.Conv2d(3, 64, 3),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(64, 10)
)
适用:快速 PoC、网络无分支。
2.2 ModuleList / ModuleDict —— 乐高式复用
class TinyResNet(nn.Module):
def __init__(self, n_blocks=4):
super().__init__()
self.blocks = nn.ModuleList([
Bottleneck(64, 64) for _ in range(n_blocks)
])
def forward(self, x):
for blk in self.blocks:
x = blk(x)
return x
适用:重复单元、需要动态深度。
3. 模型修改:三大高频需求一次讲透
以 torchvision.models.resnet50()
为例。
需求 | 关键 API / 技巧 | 代码片段 |
---|---|---|
改输出类别 | 直接替换 fc 层 |
net.fc = nn.Linear(2048, 10) |
加额外输入 | forward 里 torch.cat |
x = torch.cat([net(x), add_var.unsqueeze(1)], 1) |
多输出/中间特征 | 修改 forward 的 return | return out, feature |
所有修改只需继承
nn.Module
并重写__init__
与forward
,无需动原始源码。
4. 模型保存与加载:单卡/多卡一次说清
4.1 存什么?
方式 | 命令 | 优缺点 |
---|---|---|
仅权重 | torch.save(model.state_dict(), path) |
轻量、跨环境兼容 |
整个模型 | torch.save(model, path) |
含结构,但依赖原始类定义和 Python 版本 |
实战建议:99% 场景只存权重。
4.2 单卡 ↔ 多卡权重前缀问题
- 多卡训练会引入
"module."
前缀 - 通用解法:存权重时统一存
model.module.state_dict()
,或加载时 strip 前缀:
state = torch.load('multi_gpu.pth')
new_state = {k[7:]: v for k, v in state.items()} # 去掉 'module.'
model.load_state_dict(new_state)
4.3 断点续训:把训练状态一起打包
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
'best_acc': best_acc
}, 'checkpoint.pth')
# 恢复
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1
5. 一条完整的开发流水线示例
6. 小结 & 行动清单
任务场景 | 立即能做的最小行动 |
---|---|
快速搭 baseline | 用 nn.Sequential 10 行内出模型 |
迁移学习 | 把 ResNet50 的 fc 替换成你的类别数 |
断电续训练 | 把 optimizer & epoch 一起写进 checkpoint |
8 卡训练 → 单卡推理 | 保存 model.module.state_dict() |
记住一句话:权重是模型的灵魂,结构是容器;容器可以重建,灵魂必须妥善保存。
参考资料:
《深入浅出PyTorch》第5章 5.1–5.4(DatawhaleChina 团队)
官方文档:torch.save
/ torch.load
/ nn.DataParallel