【Pytorch常用模块总结】

发布于:2025-03-05 ⋅ 阅读:(111) ⋅ 点赞:(0)

数据准备

数据集预处理
  • torchvision.transforms
    • transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为张量
    • transforms.Normalize(mean, std):标准化数据,指定均值和标准差
    • transforms.Resize(size):调整图像大小
    • transforms.RandomCrop(size):随机裁剪
    • transforms.RandomHorizontalFlip():随机水平翻转,用于数据增强
    • transforms.Compose(transforms_list):组合多个变换
数据集的导入
  • 自建数据集
    • torch.utils.data.Dataset
      • __init__:初始化数据集(如加载文件路径、标签)
      • __len__:返回数据集大小
      • __getitem__:定义如何获取单个样本及其标签
      • 可搭配 torchvision.transforms 进行预处理
  • 通用数据集
    • torchvision.datasets
      • 示例:torchvision.datasets.MNIST(root='./data', train=True, download=True)
      • 参数:train=True/False 区分训练集和测试集
数据集的加载
  • torch.utils.data.DataLoader
    • 参数:
      • batch_size:批次大小
      • shuffle=True/False:是否打乱数据(训练 True,测试 False)
      • num_workers:多线程加载数据的线程数
      • drop_last=True:丢弃最后一个不完整批次

定义模型

  • torch.nn
    • nn.Module:自定义模型需继承并实现 forward 方法
    • 常用层
      • nn.Linear(in_features, out_features):全连接层
      • nn.Conv2d(in_channels, out_channels, kernel_size):二维卷积层
      • nn.MaxPool2d(kernel_size):最大池化层
    • 激活函数
      • nn.ReLU()nn.Sigmoid()
    • 正则化和归一化
      • nn.Dropout(p):随机丢弃,防止过拟合
      • nn.BatchNorm2d(num_features):批归一化
    • nn.Sequential:快速构建简单网络

定义损失函数

  • torch.nn
    • nn.CrossEntropyLoss():交叉熵损失(含 Softmax),多分类任务
    • nn.MSELoss():均方误差,回归任务
    • nn.BCELoss() / nn.BCEWithLogitsLoss():二分类任务
    • 根据任务选择合适的损失函数

定义优化器

  • torch.optim
    • optim.SGD(model.parameters(), lr, momentum):随机梯度下降
    • optim.Adam(model.parameters(), lr, weight_decay):Adam 优化器
    • 参数:
      • lr:学习率
      • momentum:动量法参数(SGD)
      • weight_decay:L2 正则化参数
    • 学习率调度器
      • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma):按步长衰减
      • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer):根据指标调整

训练模型

  • torch.nn.Module
    • model.to(device):将模型移到 GPU/CPU,device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    • model.train():进入训练模式
    • torch.no_grad():禁用梯度计算(验证/测试时使用)
    • 模型保存与加载
      • torch.save(model.state_dict(), 'path.pth'):保存模型参数
      • torch.save(model, 'path.pth'):保存整个模型
      • model.load_state_dict(torch.load('path.pth')):加载模型参数
  • 训练流程
    • 计算损失 → 梯度置零 → 反向传播 → 更新参数
      • loss = loss_fn(output, y)
      • optimizer.zero_grad()
      • loss.backward()
      • optimizer.step()
    • 示例:
      for epoch in range(num_epochs):
          model.train()
          for batch_x, batch_y in data_loader:
              batch_x, batch_y = batch_x.to(device), batch_y.to(device)
              output = model(batch_x)
              loss = loss_fn(output, batch_y)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
      

网站公告

今日签到

点亮在社区的每一天
去签到