AF3 rot_matmul 和 rot_vec_mul函数解读

发布于:2025-03-26 ⋅ 阅读:(72) ⋅ 点赞:(0)

AlphaFold3  rigid_utils 模块的 rot_matmul 和 rot_vec_mul 函数实现了手动计算两个旋转矩阵的乘法 A×B 以及矩阵-向量乘法 R×t, 避免了直接用矩阵乘法的AMP(Automatic Mixed Precision)问题。

源代码:

def rot_matmul(
    a: torch.Tensor, 
    b: torch.Tensor
) -> torch.Tensor:
    """
        Performs matrix multiplication of two rotation matrix tensors. Written
        out by hand to avoid AMP downcasting.

        Args:
            a: [*, 3, 3] left multiplicand
            b: [*, 3, 3] right multiplicand
        Returns:
            The product ab
    """
    def row_mul(i):
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, 0]
                + a[..., i, 1] * b[..., 1, 0]
                + a[..., i, 2] * b[..., 2, 0],
                a[..., i, 0] * b[..., 0, 1]
                + a[..., i, 1] * b[..., 1, 1]
                + a[..., i, 2] * b[..., 2, 1],
                a[..., i, 0] * b[..., 0, 2]
                + a[..., i, 1] * b[..., 1, 2]
                + a[..., i, 2] * b[..., 2, 2],
            ],
            dim=-1,
        )

    return torch.stack(
        [
            row_mul(0), 
            row_mul(1), 
            row_mul(2),
        ], 
        dim=-2
    )


def rot_vec_mul(
    r: torch.Tensor, 
    t: torch.Tensor
) -> torch.Tensor:
    """
        Applies a rotation to a vector. Written out by hand to avoid transfer
        to avoid AMP downcasting.

        Args:
            r: [*, 3, 3] rotation matrices
            t: [*, 3] coordinate tensors
        Returns:
            [*, 3] rotated coordinates
    """
    x, y, z = torch.unbind(t, dim=-1)
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )

代码解读:

def rot_matmul(
    a: torch.Tensor, 
    b: torch.Tensor
) -> torch.Tensor:


网站公告

今日签到

点亮在社区的每一天
去签到