【学习笔记】反向传播到底是如何进行的?

发布于:2024-12-18 ⋅ 阅读:(76) ⋅ 点赞:(0)

一、写在前面

不知道小伙伴们有没有考虑过这种感觉,在最开始学习深度学习的时候,一定都了解过前向传播,反向传播等等,但是在实际的操作过程中却“几乎用不到”,那么反向传播过程在代码中到底是如何进行的呢?今天让我们来回顾一下。

跑过深度学习代码的小伙伴们想必都见过下面这几行:

self.opt.zero_grad()
loss = nn.MSELoss()(var_pred, var_true)
loss.backward()
self.opt.step()

接下来我们就重点分析这几行代码的作用!

  • self.opt.zero_grad()

进行梯度清零,至于为什么要进行梯度清零,可以先看下面。

  • loss = nn.MSELoss()(var_pred, var_true)

计算损失。

  • loss.backward()

回传损失,计算梯度。

  • self.opt.step()

更新权重。

二、举个例子

我们可以举一个简单的全连接层的例子,下面是具体的代码。

   import torch.nn as nn

   # 定义模型
   class SimpleModel(nn.Module):
       def __init__(self):
           super(SimpleModel, self).__init__()
           self.fc = nn.Linear(1, 1)  # 全连接层

       def forward(self, x):
           return self.fc(x)
   # 创建模型实例
   model = SimpleModel()

   # 定义损失函数
   criterion = nn.MSELoss()

   # 定义优化器
   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

   # 输入数据
   input_data = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
   target = torch.tensor([[1.5], [2.5], [3.5]])
   for name, param in model.named_parameters():
       print(name, param)
   print()

   # 前向传播
   output = model(input_data)
   print("output:", output)
   # 计算损失
   loss = criterion(output, target)
   print(input_data.grad)
   # 反向传播 计算梯度
   loss.backward()
   print("loss:", loss.item())
   print(input_data.grad)

   for name, param in model.named_parameters():
       print(name, param, param.grad)
   print()

   # 权重更新
   optimizer.step()
   for name, param in model.named_parameters():
       print(name, param, param.grad)
   print()

在上面的例子中,我们定义的神经网络的数学公式为:
y = w x + b y=wx+b y=wx+b

那么这个神经网络的权重是如何更新的呢?

  1. 权重指的是什么?

在本例中就是w和b,也就是除了输入和输出以外的其他变量。

  1. 在我们创建模型的时候会自动初始化一个模型的权重;

在本例中,我们可以打印出来初始的权重:
fc.weight Parameter containing:
tensor([[-0.2505]], requires_grad=True)
fc.bias Parameter containing:
tensor([0.9166], requires_grad=True)
也就是说初始的w=-0.2505,b=0.9166;

  1. 前向传播,计算输出:

根据公式y=f(x),在本例中f(x)=wx+b,换成其他的神经网络也类似计算。
带入x可以获得输出:
output: tensor([[0.6661], [0.4156], [0.1651]],
grad_fn=<AddmmBackward0>)

  1. 计算损失:

本例中采用的损失函数为MSE Loss,计算公式如下:
L = 1 N ∑ i = 1 i = N ( y i − y i ^ ) 2 L=\frac{1}{N}\sum_{i=1}^{i=N}(y_i-\hat{y_i} ) ^2 L=N1i=1i=N(yiyi^)2
其中 y y y代表输出, y ^ \hat{y} y^代表真实值,带入公式计算可以获取损失值为:

注意此时只进行了前向传播,还没有进行反向传播,所以此时还未进行梯度计算,所以现在权重的梯度为None。

  1. 反向传播,计算梯度。

梯度计算公式为:
∂ L ∂ w = ∂ L ∂ y ∂ y ∂ w \frac{\partial L}{\partial w} =\frac{\partial L}{\partial y}\frac{\partial y}{\partial w} wL=yLwy
根据本例:
∂ L ∂ y = 2 N ∑ i = 1 i = N ( y i − y i ^ ) \frac{\partial L}{\partial y}=\frac{2}{N}\sum_{i=1}^{i=N}(y_i-\hat{y_i} ) yL=N2i=1i=N(yiyi^),其中 y ^ \hat{y} y^是常量;

∂ y ∂ w = x \frac{\partial y}{\partial w}=x wy=x

所以:
∂ L ∂ w = ∂ L ∂ y ∂ y ∂ w = 2 N ∑ i = 1 i = N ( y i − y i ^ ) x i \frac{\partial L}{\partial w} =\frac{\partial L}{\partial y}\frac{\partial y}{\partial w}=\frac{2}{N}\sum_{i=1}^{i=N}(y_i-\hat{y_i} )x_i wL=yLwy=N2i=1i=N(yiyi^)xi
带入值可以求得: ∂ L ∂ w = − 10.0049 \frac{\partial L}{\partial w} =-10.0049 wL=10.0049
同理:
∂ L ∂ b = ∂ L ∂ y ∂ y ∂ b \frac{\partial L}{\partial b} =\frac{\partial L}{\partial y}\frac{\partial y}{\partial b} bL=yLby
∂ y ∂ b = 1 \frac{\partial y}{\partial b}=1 by=1
所以:
∂ L ∂ b = ∂ L ∂ y ∂ y ∂ b = 2 N ∑ i = 1 i = N ( y i − y i ^ ) \frac{\partial L}{\partial b} =\frac{\partial L}{\partial y}\frac{\partial y}{\partial b}=\frac{2}{N}\sum_{i=1}^{i=N}(y_i-\hat{y_i} ) bL=yLby=N2i=1i=N(yiyi^)
带入值可以求得: ∂ L ∂ b = − 4.1688 \frac{\partial L}{\partial b} =-4.1688 bL=4.1688

  1. 更新权重

w = w − η ∂ y ∂ w w=w-\eta \frac{\partial y}{\partial w} w=wηwy
b = b − η ∂ y ∂ b b=b-\eta \frac{\partial y}{\partial b} b=bηby
其中 η \eta η 是 学习率;
带入值得:

上面就是正常梯度的更新过程,所以为什么每次都要先进行梯度清零现在应该清楚了吧,如果不进行清零,最后在更新权重的时候梯度会有累加。

三、混合损失函数如何进行呢?

  • 不同损失函数之间的混合:
    例如:L1 = MSE,L2 = MAE, L=L1+L2

    1. 分别计算损失,L1, L2,然后求和 L=L1+L2;
    2. 分别计算梯度,然后将梯度求和,
    3. 更新权重。
      这种混合方式看哪个损失函数计算的损失大,那么哪个损失函数对权重更新的影响就大。
  • 相同损失函数,但是对不同输出进行:
    例如:L1=MSE(output,target), L2 =MSE(output.sum(), target.sum())
    这种与之前的类似:

    1. 分别计算损失,L1, L2,然后求和 L=L1+L2;
    2. 分别计算梯度,然后将梯度求和,
    3. 更新权重。
      这种计算方式同样也是看哪种计算方式获取的损失大,哪种对权重的更新影响就大。

网站公告

今日签到

点亮在社区的每一天
去签到