PyTorch自动求导

发布于:2025-08-20 ⋅ 阅读:(14) ⋅ 点赞:(0)

1. 计算图构建过程

x = torch.ones(5, requires_grad=True)  # 定义叶子节点,启用梯度跟踪
y = x + 2                             # 加法操作,生成中间节点 y
z = y * y * 3                         # 平方与乘法操作,生成中间节点 z
out = z.mean()                        # 标量输出(损失函数)
  • 动态计算图构建​:

    每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):

    x → (Add) → y → (Pow + Mul) → z → (Mean) → out

    节点类型:

    • 叶子节点​:用户直接创建的 xx.is_leaf = True)。
    • 非叶子节点​:y, z, out由运算生成(grad_fn属性记录操作类型)
  • 梯度跟踪机制​:

    设置 requires_grad=True后,所有依赖 x的中间节点自动继承此属性(如 y.requires_grad=True


2. 反向传播与梯度计算

out.backward()  # 触发反向传播
  • 反向传播流程​:
    1. 1.out开始反向遍历​:因 out是标量(shape=()),无需额外指定梯度权重
    2. 2.

      链式法则应用​:

      • out = z.mean()→ ∂zi​∂out​=51​(z有 5 个元素)。
      • z = 3y^2→ ∂yi​∂zi​​=6yi​。
      • y = x + 2→ ∂xi​∂yi​​=1
    3. 3.​梯度计算​:

      ∂xi​∂out​=∂zi​∂out​⋅∂yi​∂zi​​⋅∂xi​∂yi​​=51​⋅6yi​⋅1=56​(xi​+2)。

  • •​梯度存储​:

    结果存入叶子节点 x.grad,非叶子节点(如 y, z)的梯度默认不保留以节省内存


3. 梯度结果验证

print(f"x 的梯度: {x.grad}")  # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
  • •​数学推导​:

    代入 xi​=1:

    ∂xi​∂out​=56​(1+2)=518​=3.6。

    与代码输出一致,验证了链式法则的正确性


4. 梯度累积问题

  • •​默认行为​:

    backward()计算的梯度会累加x.grad。若多次执行 out.backward(),梯度将叠加(如运行两次后 x.grad变为 [7.2, 7.2, ...]

  • 解决方案​:

    训练循环中需在每次反向传播前调用 x.grad.zero_()optimizer.zero_grad()清零梯度


关键概念总结

概念

说明

代码示例

叶子节点

用户直接创建的张量,梯度计算终点

x = torch.ones(5, requires_grad=True)

动态计算图

运行时动态构建的操作依赖图,反向传播后自动释放

y = x + 2生成 AddBackward节点

非标量反向传播

out非标量(如向量),需传入 gradient参数作为权重矩阵

z.backward(torch.ones_like(z))

梯度保留

设置 retain_graph=True可保留计算图,支持多次反向传播

out.backward(retain_graph=True)


提示​:理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)查看操作类型(如输出 <AddBackward0>),或使用 torchviz库可视化计算图


网站公告

今日签到

点亮在社区的每一天
去签到