1. 计算图构建过程
x = torch.ones(5, requires_grad=True) # 定义叶子节点,启用梯度跟踪
y = x + 2 # 加法操作,生成中间节点 y
z = y * y * 3 # 平方与乘法操作,生成中间节点 z
out = z.mean() # 标量输出(损失函数)
- 动态计算图构建:
每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):
x → (Add) → y → (Pow + Mul) → z → (Mean) → out
节点类型:
- 叶子节点:用户直接创建的
x
(x.is_leaf = True
)。 - 非叶子节点:
y
,z
,out
由运算生成(grad_fn
属性记录操作类型)
- 叶子节点:用户直接创建的
- 梯度跟踪机制:
设置
requires_grad=True
后,所有依赖x
的中间节点自动继承此属性(如y.requires_grad=True
)
2. 反向传播与梯度计算
out.backward() # 触发反向传播
- •反向传播流程:
- 1.从
out
开始反向遍历:因out
是标量(shape=()
),无需额外指定梯度权重
。 - 2.
链式法则应用:
out = z.mean()
→ ∂zi∂out=51(z
有 5 个元素)。z = 3y^2
→ ∂yi∂zi=6yi。y = x + 2
→ ∂xi∂yi=1
- 3.梯度计算:
∂xi∂out=∂zi∂out⋅∂yi∂zi⋅∂xi∂yi=51⋅6yi⋅1=56(xi+2)。
- 1.从
- •梯度存储:
结果存入叶子节点
x.grad
,非叶子节点(如y
,z
)的梯度默认不保留以节省内存
3. 梯度结果验证
print(f"x 的梯度: {x.grad}") # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
- •数学推导:
代入 xi=1:
∂xi∂out=56(1+2)=518=3.6。
与代码输出一致,验证了链式法则的正确性
4. 梯度累积问题
- •默认行为:
backward()
计算的梯度会累加到x.grad
。若多次执行out.backward()
,梯度将叠加(如运行两次后x.grad
变为[7.2, 7.2, ...]
) - 解决方案:
训练循环中需在每次反向传播前调用
x.grad.zero_()
或optimizer.zero_grad()
清零梯度
关键概念总结
概念 |
说明 |
代码示例 |
---|---|---|
叶子节点 |
用户直接创建的张量,梯度计算终点 |
|
动态计算图 |
运行时动态构建的操作依赖图,反向传播后自动释放 |
|
非标量反向传播 |
若 |
|
梯度保留 |
设置 |
|
提示:理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)
查看操作类型(如输出 <AddBackward0>
),或使用 torchviz
库可视化计算图