DAY02:【pytorch】计算图与动态图机制

发布于:2025-04-13 ⋅ 阅读:(71) ⋅ 点赞:(0)

一、引言

在深度学习框架中,计算图是理解自动求导和模型优化的核心概念。无论是PyTorch的动态图机制,还是TensorFlow早期的静态图模式,计算图都扮演着关键角色。本文将深入解析计算图的基本原理,并结合PyTorch代码演示动态图的运行机制,帮助读者建立从理论到实践的完整认知

二、计算图基础:用图结构描述运算

计算图(Computational Graph)是一种用于描述数学运算的有向无环图(DAG),包含两个核心元素:

  • 结点(Node):表示数据,如向量、矩阵、张量(Tensor)等
  • 边(Edge):表示运算,如加减乘除、卷积、矩阵乘法等

以公式 y = ( x + w ) × ( w + 1 ) y = (x + w) \times (w + 1) y=(x+w)×(w+1) 为例,其计算图构建过程如下:

  1. 定义叶子结点 x x x w w w(用户创建的初始数据)
  2. 中间结点 a = x + w a = x + w a=x+w(加法运算)
  3. 中间结点 b = w + 1 b = w + 1 b=w+1(加法运算)
  4. 输出结点 y = a × b y = a \times b y=a×b(乘法运算)

用图形表示为:

x ── + ──> a ──┐
     │         ├─> y
w ── + ──> b ──┘
     │
     └─> 1

三、计算图与梯度求导:自动求导的核心逻辑

在深度学习中,梯度求导是优化模型参数的关键步骤。计算图通过**反向传播(Backpropagation)**实现梯度的高效计算,核心原理是链式法则

1、叶子结点与非叶子结点

  • 叶子结点(Leaf Node):用户直接创建的结点(如 x x x w w w),其 is_leaf 属性为 True
  • 非叶子结点:由运算生成的中间结点(如 a a a b b b y y y),其 is_leaf 属性为 False

2、梯度与梯度函数

  • 梯度(grad):存储结点关于最终输出的导数。叶子结点的梯度在反向传播后被填充,非叶子结点的梯度默认不保留(除非调用 retain_grad()
  • 梯度函数(grad_fn):记录创建该结点时使用的运算函数,用于反向传播时计算梯度。例如:
    • 加法运算的梯度函数为 <AddBackward0>
    • 乘法运算的梯度函数为 <MulBackward0>

3. 链式法则示例

对于 y = ( x + w ) × ( w + 1 ) y = (x + w) \times (w + 1) y=(x+w)×(w+1),计算 ∂ y ∂ w \frac{\partial y}{\partial w} wy

∂ y ∂ w = ∂ y ∂ a ⋅ ∂ a ∂ w + ∂ y ∂ b ⋅ ∂ b ∂ w = b ⋅ 1 + a ⋅ 1 = ( w + 1 ) + ( x + w ) \frac{\partial y}{\partial w} = \frac{\partial y}{\partial a} \cdot \frac{\partial a}{\partial w} + \frac{\partial y}{\partial b} \cdot \frac{\partial b}{\partial w} = b \cdot 1 + a \cdot 1 = (w + 1) + (x + w) wy=aywa+bywb=b1+a1=(w+1)+(x+w)

x = 2 x=2 x=2 w = 1 w=1 w=1 时,梯度为 ( 1 + 1 ) + ( 2 + 1 ) = 5 (1+1)+(2+1)=5 (1+1)+(2+1)=5

四、PyTorch动态图机制:灵活高效的计算模式

根据计算图的搭建方式,可分为两种模式:

动态图(Dynamic Graph) 静态图(Static Graph)
运算与图搭建同时进行 先搭建图,后执行运算
灵活易调节(支持条件语句) 高效但缺乏灵活性
代表框架:PyTorch 代表框架:TensorFlow 1.x

1、动态图的核心优势

  1. 即时执行:每一步运算立即生效,方便调试和交互式开发
  2. 动态控制流:支持循环、条件判断等Python原生控制逻辑,模型结构可动态调整

2、PyTorch动态图示例

以下代码展示了动态图的构建与反向传播过程(对应文档中的公式 y = ( x + w ) × ( w + 1 ) y = (x + w) \times (w + 1) y=(x+w)×(w+1)):

import torch

# 创建叶子结点(requires_grad=True表示需要计算梯度)
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

# 中间结点运算
a = torch.add(w, x)  # a = w + x
a.retain_grad()      # 保留中间变量a的梯度(非叶子结点默认不保留)
b = torch.add(w, 1)  # b = w + 1
y = torch.mul(a, b)  # y = a * b

# 反向传播计算梯度
y.backward()

# 查看叶子结点状态
print("is_leaf:", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
# 输出:is_leaf: True True False False False

# 查看梯度(w的梯度应为5,x的梯度为b=2)
print("grad:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 输出:grad: tensor([5.]) tensor([2.]) tensor([2.]) tensor([3.]) None

# 查看梯度函数(记录运算类型)
print("grad_fn:", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
# 输出:grad_fn: None None <AddBackward0 object at ...> <AddBackward0 object at ...> <MulBackward0 object at ...>

代码解析

  1. 叶子结点wxis_leafTrue,其余结点为 False
  2. 梯度计算
    • w.grad5,对应公式推导结果
    • x.gradb=2(因为 ∂ y ∂ x = ∂ y ∂ a ⋅ ∂ a ∂ x = b ⋅ 1 = 2 \frac{\partial y}{\partial x} = \frac{\partial y}{\partial a} \cdot \frac{\partial a}{\partial x} = b \cdot 1 = 2 xy=ayxa=b1=2
  3. 梯度函数
    • a.grad_fnb.grad_fn 均为加法运算的反向函数 <AddBackward0>
    • y.grad_fn 为乘法运算的反向函数 <MulBackward0>

微语录:尽最大的努力,留最小的遗憾。