以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab
# Initialize the SwanLab experiment (auto-generates a dashboard for this run).
# The values in `config` are recorded as the run's hyperparameters and are
# read back later via `swanlab.config.<key>`.
swanlab.init(
    experiment_name="MNIST_CNN",
    description="Simple CNN on MNIST with SwanLab monitoring",
    config={
        "batch_size": 64,
        "epochs": 10,
        "learning_rate": 0.01,
        "model": "CNN"
    }
)
# 1. Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    # (0.1307,), (0.3081,) are the commonly used MNIST channel mean/std.
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
# Batch size comes from the swanlab config so it is tracked with the run.
train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
# Larger fixed batch for evaluation; no shuffling needed at test time.
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
# 2. CNN model definition
class CNN(nn.Module):
    """Two conv layers followed by a two-layer classifier head for MNIST.

    Expects input of shape (N, 1, 28, 28) and returns per-class
    log-probabilities of shape (N, 10).
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        # Two 3x3 convs shrink 28 -> 26 -> 24; 2x2 pooling gives 12,
        # so the flattened feature map is 64 * 12 * 12 = 9216.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: conv -> relu -> conv -> relu -> pool -> dropout
        features = nn.functional.relu(self.conv1(x))
        features = nn.functional.relu(self.conv2(features))
        features = self.dropout(nn.functional.max_pool2d(features, 2))
        # Classifier head on the flattened feature map
        flat = torch.flatten(features, 1)
        hidden = self.dropout(nn.functional.relu(self.fc1(flat)))
        logits = self.fc2(hidden)
        return nn.functional.log_softmax(logits, dim=1)
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
# Learning rate is read from the tracked swanlab config.
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)
# 3. Training loop
def train(epoch):
    """Run one training epoch, logging the batch loss every 100 batches.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``; logs to the active SwanLab run.
    """
    model.train()
    total_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 != 0:
            continue
        # Record the batch loss under a globally increasing step index.
        global_step = epoch * total_batches + batch_idx
        swanlab.log({"train_loss": loss.item()}, step=global_step)
        # Console progress line
        print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")
# 4. Evaluation function
def test(epoch):
    """Evaluate on the full test set and log epoch-level metrics.

    Uses the module-level ``model``, ``test_loader`` and ``device``;
    logs average loss and accuracy to the active SwanLab run.
    """
    model.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so we can divide by dataset size below.
            total_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
    dataset_size = len(test_loader.dataset)
    test_loss = total_loss / dataset_size
    accuracy = 100. * num_correct / dataset_size
    # Epoch-level metrics for the dashboard
    swanlab.log({
        "test_loss": test_loss,
        "accuracy": accuracy,
        "epoch": epoch
    })
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")
# 5. Run training: one train + evaluation pass per epoch.
for epoch in range(1, swanlab.config.epochs + 1):
    train(epoch)
    test(epoch)
print("训练完成!请在 https://swanlab.cn 查看可视化结果")
关键说明:
SwanLab 初始化:
swanlab.init() # 创建实验并设置跟踪参数
实时日志记录:
swanlab.log({"train_loss": loss.item()}) # 记录每个batch的损失
指标可视化:
swanlab.log({"accuracy": accuracy, "test_loss": test_loss}) # 记录测试指标
使用步骤:
- 安装依赖:
pip install torch torchvision swanlab
- 运行脚本:
python mnist_example.py
- 查看结果:
  - 终端会自动打印监控链接,例如:
    SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]
  - 或在 SwanLab 官网登录查看
仪表盘功能:
实时监控:
- 训练损失曲线(每100个batch更新)
- 测试精度/损失曲线(每个epoch更新)
实验管理:
- 记录所有超参数(batch_size, lr等)
- 保存实验配置和系统环境
- 对比多次运行结果
自动分析:
- 训练过程动态可视化
- 指标变化趋势分析
- 性能指标汇总统计
通过这个示例,你可以实时:
- 监控训练损失下降趋势
- 观察模型在验证集的性能变化
- 分析不同超参数对结果的影响
- 比较多次实验的结果差异
SwanLab 会自动保存所有实验数据,即使训练中断也能恢复可视化结果。