PyTorch 中torch.einsum函数的使用详解和工程应用示例

发布于:2025-06-12 ⋅ 阅读:(25) ⋅ 点赞:(0)

torch.einsum 是 PyTorch 中非常强大的函数,用于 灵活定义张量的乘法、求和、转置、点乘、外积等各种线性代数操作。它源自 Einstein Summation(爱因斯坦求和约定),是一种更紧凑、可读性更强的多维操作方式。


函数原型

torch.einsum(equation, *operands)
  • equation: 字符串,定义操作规则。
  • operands: 一个或多个 Tensor,参与运算的张量。

爱因斯坦求和规则简要

equation 中:

  • 字母表示张量的维度;
  • 重复的字母表示该维度将被求和;
  • 不重复的字母表示结果张量保留的维度。

常见用途与示例

1. 向量内积(dot product)

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.1, 0.2, 0.3])
res = torch.einsum('i,i->', a, b)
print(res)  # 输出:1.4

解释:i,i-> 表示对所有 i 求乘积再求和。


2. 矩阵乘法(matrix multiplication)

A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.shape)  # torch.Size([2, 4])

相当于 torch.matmul(A, B)


3. 矩阵向量乘法

A = torch.randn(3, 4)
x = torch.randn(4)
y = torch.einsum('ij,j->i', A, x)
print(y.shape)  # torch.Size([3])

4. 批量矩阵乘法

A = torch.randn(5, 2, 3)  # batch_size=5
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
print(C.shape)  # torch.Size([5, 2, 4])

5. 转置(矩阵维度调换)

A = torch.randn(3, 4)
B = torch.einsum('ij->ji', A)
print(B.shape)  # torch.Size([4, 3])

6. 外积(outer product)

a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0, 30.0])
res = torch.einsum('i,j->ij', a, b)
print(res)
# 输出形状为 (2, 3)

7. 计算每个 batch 的向量 L2 范数平方

x = torch.randn(32, 128)  # batch_size=32, dim=128
norm_sq = torch.einsum('bi,bi->b', x, x)

8. 计算注意力权重矩阵

q = torch.randn(32, 8, 64)  # query: batch, head, dim
k = torch.randn(32, 8, 64)
attn_scores = torch.einsum('bhd,bhd->bh', q, k)

工程应用示例

图邻接矩阵传播(Graph Adjacency Propagation)通常用于图神经网络(GNN)中,表示将节点特征通过邻接结构进行消息传递(message passing)或特征聚合(feature aggregation)的过程。
我们将使用 torch.einsum 实现这个过程。


1. 图邻接传播的数学表达

设:

  • A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N:邻接矩阵(可选归一化)
  • X ∈ R N × F X \in \mathbb{R}^{N \times F} XRN×F:节点特征矩阵(N 个节点,每个有 F 维特征)
  • A X AX AX:每个节点从邻居节点聚合特征

2. PyTorch 示例代码(einsum 实现)

import torch

# 假设图有 4 个节点,每个节点有 3 维特征
X = torch.tensor([[1.0, 0.5, 2.0],
                  [0.3, 1.2, 0.7],
                  [0.8, 0.1, 1.1],
                  [0.0, 0.3, 0.4]])  # (4, 3)

# 邻接矩阵 A(可为稀疏或归一化矩阵)
A = torch.tensor([[1, 1, 0, 0],
                  [1, 1, 1, 0],
                  [0, 1, 1, 1],
                  [0, 0, 1, 1]], dtype=torch.float32)  # (4, 4)

# 使用 torch.einsum 进行图传播 AX
# 'ij,jk->ik':A(i,j) * X(j,k) → 输出 (i,k)
X_agg = torch.einsum('ij,jk->ik', A, X)

print("聚合后的特征:")
print(X_agg)

3. 加权传播(带权重)

若有可学习的线性层 W ∈ R F × F ′ W \in \mathbb{R}^{F \times F'} WRF×F,传播过程变为:

A X W AXW AXW

代码示例:

# W 是可学习的线性变换
W = torch.nn.Linear(in_features=3, out_features=2, bias=False)

# XW: 节点特征线性变换
X_transformed = W(X)  # (4, 2)

# 再传播
X_out = torch.einsum('ij,jk->ik', A, X_transformed)

print("传播后的新特征维度:", X_out.shape)

4. 扩展到 Batch 形式(多个图)

设:

  • X: (B, N, F) — batch_size 个图,每图 N 节点 F 维特征
  • A: (B, N, N) — batch 的邻接矩阵
# B 个图,每图 4 节点,每个节点 3 维特征
X = torch.randn(8, 4, 3)       # (B, N, F)
A = torch.eye(4).repeat(8, 1, 1)  # (B, N, N)

# 批量邻接传播
X_agg = torch.einsum('bij,bjf->bif', A, X)  # 输出 (B, N, F)

5. 总结

操作 einsum 表达 说明
图传播(AX) 'ij,jk->ik' 基础图邻接传播
批量传播 'bij,bjf->bif' 批次图传播
权重传播 AX @ W or 'ij,jk->ik' then Linear 加入特征变换

为什么用 einsum

  • 替代嵌套的 permute + view + matmul
  • 更接近数学表达;
  • 性能有时更优;
  • 可用于写清晰的复杂操作,如 self-attention、卷积、图神经网络等。

小技巧

  • einsum_path 可用于优化路径选择:
torch.einsum_path('bij,bjk->bik', A, B, optimize='optimal')

总结

功能 示例公式
向量点积 'i,i->'
矩阵乘法 'ik,kj->ij'
外积 'i,j->ij'
转置 'ij->ji'
batch matmul 'bij,bjk->bik'
L2 norm square 'bi,bi->b'

网站公告

今日签到

点亮在社区的每一天
去签到