AF3 Vec3Array类源码解读

发布于:2025-02-10 ⋅ 阅读:(69) ⋅ 点赞:(0)

Vec3Array 类及其相关运算定义在 AlphaFold3 的 `src.utils.geometry.vector` 模块中。Vec3Array 类实现了一个用于处理三维向量数组的数据结构，支持基本的算术运算（向量之间的加减，以及与标量或张量的乘除），以及常用的几何运算（点积、叉积、归一化等）。

源代码：

"""Vec3Array Class."""

from __future__ import annotations
import dataclasses
from typing import Union, List, Optional

import torch

# Operand type accepted by the scaling methods of Vec3Array (``__mul__``,
# ``__rmul__``, ``__truediv__``): either a plain Python float or a torch tensor.
Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Vec3Array:
    x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    y: torch.Tensor
    z: torch.Tensor

    def __post_init__(self):
        if hasattr(self.x, 'dtype'):
            assert self.x.dtype == self.y.dtype
            assert self.x.dtype == self.z.dtype
            assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
            assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])

    def __add__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x + other.x,
            self.y + other.y,
            self.z + other.z,
        )

    def __sub__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x - other.x,
            self.y - other.y,
            self.z - other.z,
        )

    def __mul__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x * other,
            self.y * other,
            self.z * other,
        )

    def __rmul__(self, other: Float) -> Vec3Array:
        return self * other

    def __truediv__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x / other,
            self.y / other,
            self.z / other,
        )

    def __neg__(self) -> Vec3Array:
        return self * -1

    def __pos__(self) -> Vec3Array:
        return self * 1

    def __getitem__(self, index) -> Vec3Array:
        return Vec3Array(
            self.x[index],
            self.y[index],
            self.z[index],
        )

    def __iter__(self):
        return iter((self.x, self.y, self.z))

    @property
    def shape(self):
        return self.x.shape

    def map_tensor_fn(self, fn) -> Vec3Array:
        return Vec3Array(
            fn(self.x),
            fn(self.y),
            fn(self.z),
        )

    def cross(self, other: Vec3Array) -> Vec3Array:
        """Compute cross product between 'self' and 'other'."""
        new_x = self.y * other.z - self.z * other.y
        new_y = self.z * other.x - self.x * other.z
        new_z