torch.einsum
是 PyTorch 中非常强大的函数,用于 灵活定义张量的乘法、求和、转置、点乘、外积等各种线性代数操作。它源自 Einstein Summation(爱因斯坦求和约定),是一种更紧凑、可读性更强的多维操作方式。
函数原型
torch.einsum(equation, *operands)
equation
: 字符串,定义操作规则。operands
: 一个或多个 Tensor,参与运算的张量。
爱因斯坦求和规则简要
在 equation
中:
- 字母表示张量的维度;
- 重复的字母表示该维度将被求和;
- 不重复的字母表示结果张量保留的维度。
常见用途与示例
1. 向量内积(dot product)
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.1, 0.2, 0.3])
res = torch.einsum('i,i->', a, b)
print(res) # 输出:1.4
解释:i,i->
表示对所有 i
求乘积再求和。
2. 矩阵乘法(matrix multiplication)
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.shape) # torch.Size([2, 4])
相当于 torch.matmul(A, B)
3. 矩阵向量乘法
A = torch.randn(3, 4)
x = torch.randn(4)
y = torch.einsum('ij,j->i', A, x)
print(y.shape) # torch.Size([3])
4. 批量矩阵乘法
A = torch.randn(5, 2, 3) # batch_size=5
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
print(C.shape) # torch.Size([5, 2, 4])
5. 转置(矩阵维度调换)
A = torch.randn(3, 4)
B = torch.einsum('ij->ji', A)
print(B.shape) # torch.Size([4, 3])
6. 外积(outer product)
a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0, 30.0])
res = torch.einsum('i,j->ij', a, b)
print(res)
# 输出形状为 (2, 3)
7. 计算每个 batch 的向量 L2 范数平方
x = torch.randn(32, 128) # batch_size=32, dim=128
norm_sq = torch.einsum('bi,bi->b', x, x)
8. 计算注意力权重矩阵
q = torch.randn(32, 8, 64) # query: batch, head, dim
k = torch.randn(32, 8, 64)
attn_scores = torch.einsum('bhd,bhd->bh', q, k)
工程应用示例
图邻接矩阵传播(Graph Adjacency Propagation)通常用于图神经网络(GNN)中,表示将节点特征通过邻接结构进行消息传递(message passing)或特征聚合(feature aggregation)的过程。
我们将使用 torch.einsum
实现这个过程。
1. 图邻接传播的数学表达
设:
- A ∈ R N × N A \in \mathbb{R}^{N \times N} A∈RN×N:邻接矩阵(可选归一化)
- X ∈ R N × F X \in \mathbb{R}^{N \times F} X∈RN×F:节点特征矩阵(N 个节点,每个有 F 维特征)
- A X AX AX:每个节点从邻居节点聚合特征
2. PyTorch 示例代码(einsum 实现)
import torch
# 假设图有 4 个节点,每个节点有 3 维特征
X = torch.tensor([[1.0, 0.5, 2.0],
[0.3, 1.2, 0.7],
[0.8, 0.1, 1.1],
[0.0, 0.3, 0.4]]) # (4, 3)
# 邻接矩阵 A(可为稀疏或归一化矩阵)
A = torch.tensor([[1, 1, 0, 0],
[1, 1, 1, 0],
[0, 1, 1, 1],
[0, 0, 1, 1]], dtype=torch.float32) # (4, 4)
# 使用 torch.einsum 进行图传播 AX
# 'ij,jk->ik':A(i,j) * X(j,k) → 输出 (i,k)
X_agg = torch.einsum('ij,jk->ik', A, X)
print("聚合后的特征:")
print(X_agg)
3. 加权传播(带权重)
若有可学习的线性层 W ∈ R F × F ′ W \in \mathbb{R}^{F \times F'} W∈RF×F′,传播过程变为:
A X W AXW AXW
代码示例:
# W 是可学习的线性变换
W = torch.nn.Linear(in_features=3, out_features=2, bias=False)
# XW: 节点特征线性变换
X_transformed = W(X) # (4, 2)
# 再传播
X_out = torch.einsum('ij,jk->ik', A, X_transformed)
print("传播后的新特征维度:", X_out.shape)
4. 扩展到 Batch 形式(多个图)
设:
X
: (B, N, F) — batch_size 个图,每图 N 节点 F 维特征A
: (B, N, N) — batch 的邻接矩阵
# B 个图,每图 4 节点,每个节点 3 维特征
X = torch.randn(8, 4, 3) # (B, N, F)
A = torch.eye(4).repeat(8, 1, 1) # (B, N, N)
# 批量邻接传播
X_agg = torch.einsum('bij,bjf->bif', A, X) # 输出 (B, N, F)
5. 总结
操作 | einsum 表达 |
说明 |
---|---|---|
图传播(AX) | 'ij,jk->ik' |
基础图邻接传播 |
批量传播 | 'bij,bjf->bif' |
批次图传播 |
权重传播 | AX @ W or 'ij,jk->ik' then Linear |
加入特征变换 |
为什么用 einsum
?
- 替代嵌套的
permute
+view
+matmul
; - 更接近数学表达;
- 性能有时更优;
- 可用于写清晰的复杂操作,如 self-attention、卷积、图神经网络等。
小技巧
einsum_path
可用于优化路径选择:
torch.einsum_path('bij,bjk->bik', A, B, optimize='optimal')
总结
功能 | 示例公式 |
---|---|
向量点积 | 'i,i->' |
矩阵乘法 | 'ik,kj->ij' |
外积 | 'i,j->ij' |
转置 | 'ij->ji' |
batch matmul | 'bij,bjk->bik' |
L2 norm square | 'bi,bi->b' |