Pytorch:模块(Module类)

发布于:2024-04-26 ⋅ 阅读:(19) ⋅ 点赞:(0)


在 PyTorch 中,Module 是一个非常核心的概念,它是所有神经网络层和模型的基础类。torch.nn.Module 是构建所有神经网络的基类,在 PyTorch 中非常重要,因为它提供了网络的组织架构,并封装了权重、梯度的管理、模型参数的更新等功能。

PyTorch 中的 Linear 层ReLU 激活函数以及大多数其他神经网络层和函数都返回 torch.Tensor 类型的对象。这些返回的张量包含了经过相应层或函数处理后的数据。在神经网络中,数据通常以张量的形式在各个层之间流动。

一、Module类介绍

所有神经网络层和模型的基础类,自定义神经网络时对其继承。

1、主要功能

  1. 封装参数

    • Module 类在内部自动管理 层的参数。每当你在 Module 中定义一个层对象,如 self.conv1 = nn.Conv2d(...), PyTorch 自动将这些层的参数加入到模型的参数列表中。这些参数通过 module.parameters() 方法访问。模型参数(定义在模型内部的层的权重和偏置)默认 requires_grad=True
  2. 自动梯度计算

    • 每个 Module 可以使用 PyTorch 的自动微分(autograd)系统来自动计算和存储梯度。在 forward 方法执行运算时,PyTorch 会跟踪这些运算产生的所有张量,对应的梯度在调用张量的 backward()方法后自动计算。由于模型参数默认requires_grad=True,因此对这些参数的所有操作都将被进行自动梯度计算。
  3. 前向传播定义

    • 在定义自己的网络时,需要覆盖 Moduleforward() 方法。这是模型接收输入数据并返回输出的地方forward() 方法定义了模型的前向传播路径
  4. 模型保存和加载
    模型的保存和加载是在 PyTorch 中进行模型持久化和迁移学习的常用操作。模型可以保存为 .pt.pth 文件,包括其参数、优化器状态和其他任何相关的信息。

  • 保存模型:

    • 最简单的保存方法是使用 torch.save 来保存模型的 state_dict,这是一个包含模型参数的字典。
    torch.save(model.state_dict(), 'model_path.pth')
    
  • 加载模型:

    • 加载模型时,首先需要实例化模型对象,然后使用 load_state_dict() 方法加载参数。
    model = MyModel()
    model.load_state_dict(torch.load('model_path.pth'))
    
  1. 将模型移动到指定的设备:
    在 PyTorch 中,可以将模型和数据移动到不同的设备上(如 CPU 或 GPU),以支持不同的计算需求。
  • 使用 .to() 方法可以将模型移动到指定的设备:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
  1. 切换模型的训练和评估方式
    torch.nn.Module 提供了 .train().eval() 方法,用于切换模型的训练和评估模式。
  • 训练模式 (train):

    • 在训练模式下,所有的层都被通知模型正在训练,这对于某些特定层(如 DropoutBatchNorm)非常重要,因为它们在训练和评估时的行为不同。
    model.train()
    
  • 评估模式 (eval):

    • 评估模式用于模型测试或验证阶段,确保所有层都处于评估状态。
    model.eval()
    
  1. .parameters()方法
  • parameters() 方法返回一个迭代器,包含模型中所有的参数(通常用于传递给优化器)。
    for param in model.parameters():
        print(param.size())
    
  1. .modules()方法
  • modules() 方法返回一个迭代器,遍历模型中的所有模块(层)。这在分析模型结构或应用特定操作到每一层时非常有用。
    for module in model.modules():
        print(module)
    

2、神经网络模型使用理解

白话:
  损失函数和优化器都不是module类中的方法,而是外部的方法,但是他们都能够作用于模型的权重:由于自动微分,损失函数接收的是结果张量,因此损失函数带来的梯度会被更新给权重的梯度。而优化器接受的是module对象参数的迭代器,它能根据参数的梯度对参数进行更新。

  自定义神经网络,实际上就是定义一个,该类继承自torch.nn.Module。在对这个类进行实例化时,是使用__init__默认构造函数实例化的。实例化后得到一个神经网络对象,对该对象输入数据会被重载为输入forward函数,而forward函数就是对输入数据进行一层一层的网络层结构处理。forward函数的输出一般是对输入进行了前向传播后的结果,为了对模型参数进行训练更新,我们一般还需要定义一个损失函数;这个损失函数是torch.tensor类型的,可以调用其backward函数,进行反向传播梯度;最后定义一个优化器,进行参数权重更新(实际上这里反向传播梯度 就 相当于损失函数对权重进行求导了,改变权重的方向就是让损失更小的方向。)

反向传播并不更新参数,优化器才是用来更新参数的,反向传播只是更新梯度。 这也是为什么优化器有一个学习率。教程:张量的梯度计算


非白话 自定义神经网络的流程:

  1. 定义一个类,继承自 torch.nn.Module

    • 这个类是您自定义神经网络的基础。通过继承 torch.nn.Module,您的网络能够利用 PyTorch 提供的模块化、参数管理、梯度计算等强大功能。
  2. __init__ 方法中初始化网络层

    • 这是定义神经网络结构的地方。您可以添加诸如全连接层 (nn.Linear), 卷积层 (nn.Conv2d), 激活函数 (nn.ReLU) 等。这些层将被自动注册为模块的子项,使其参数也自动成为模型的一部分。
  3. 定义 forward 方法

    • forward 方法描述了输入数据如何通过定义的层传播。这个方法是在模型训练和评估时自动被调用的,用于前向传播计算输出。
  4. 损失函数和反向传播

    • 在训练阶段,网络输出通过一个损失函数 (loss function) 评估其与真实标签的差异。常用的损失函数有 nn.CrossEntropyLoss(用于分类任务)和 nn.MSELoss(用于回归任务)。
    • 调用损失张量的 .backward() 方法启动自动梯度计算,即反向传播。在这一过程中,PyTorch 根据损失函数自动计算每个参数的梯度,并存储在参数的 .grad 属性中。
    • 在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。清空梯度:optimizer.zero_grad()
  5. 参数更新

    • 使用一个优化器(如 torch.optim.SGDtorch.optim.Adam)来调整网络参数,基于计算的梯度进行更新,以减少损失函数的值。这通常在调用 .backward() 后进行。

a.前向传播示例代码

损失函数和优化器的例子请看:神经网络训练过程代码详解

下面是一个简单的自定义 Module 的例子,定义了一个包含两个全连接层的简单神经网络。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # 定义第一个全连接层
        self.fc1 = nn.Linear(16, 12)
        # 定义第二个全连接层
        self.fc2 = nn.Linear(12, 10)
	
    def forward(self, x):
        # 第一个全连接层的激活函数使用ReLU
        x = F.relu(self.fc1(x))
        # 第二个全连接层的输出
        x = self.fc2(x)
        return x

# 实例化网络
net = SimpleNet()#__init__()里并不需要参数。默认构造函数不需要参数,net就是一个实例化对象。
# 创建一些随机输入数据
input = torch.randn(1, 16)

# 通过网络进行前向传播
output = net(input)#实际上直接使用对象名(),重载为:调用forward函数。
#input先经过一个nn.Linear(16,12),然后进行一次relu(),然后经过一个nn.Linear(12,10)

b.关键点

  • 继承:自定义的模型需要继承自 nn.Module
  • 超类初始化:使用 super() 初始化基类,这是在 Python 类中常见的做法,确保正确初始化父类部分。
  • 定义层:在构造函数中定义网络所需的各种层。
  • 前向传播:在 forward 方法中定义数据如何通过网络。