AF3 DiffusionModule类解读

发布于:2025-02-10 ⋅ 阅读:(68) ⋅ 点赞:(0)

DiffusionModule 类是 AlphaFold3 中的关键组件：它基于扩散模型（diffusion model），在全部原子坐标上进行去噪生成，从而预测蛋白质（以及其他生物分子复合物）的三维结构。

源代码:

"""
Diffusion Module from AlphaFold3.
The StructureModule of AlphaFold2 using invariant point attention was replaced with a relatively standard
non-equivariant point-cloud diffusion model over all atoms. The denoiser is based on a modern transformer,
but with several modifications to make it more amenable to the task. The main changes are:
 - Conditioning from the trunk in several ways: initialise the activations for the single embedding, use
    a variant of Adaptive Layernorm for the single conditioning and logit biasing for the pair conditioning.
 - Standard transformer tricks (e.g. SwiGLU) and methods used in AlphaFold2 (gating)
 - A two-level architecture, working first on atoms, then tokens, then atoms again.
"""
import math
import torch
from torch import nn
from torch import Tensor
from torch.nn import LayerNorm
from typing import Dict, Tuple
from src.models.diffusion_conditioning import DiffusionConditioning
from src.models.diffusion_transformer import DiffusionTransformer
from src.models.components.atom_attention import AtomAttentionEncoder, AtomAttentionDecoder
from src.models.components.primitives import LinearNoBias
from src.utils.geometry.vector import Vec3Array
from src.diffusion.augmentation import centre_random_augmentation
from src.diffusion.noise import sample_noise_level, noise_positions


class DiffusionModule(torch.nn.Module):
    def __init__(
            self,
            c_atom: int = 128,
            c_atompair: int = 16,
            c_token: int = 768,
            c_tokenpair: int = 128,
            atom_encoder_blocks: int = 3,
            atom_encoder_heads: int = 16,
            dropout: float = 0.0,
            atom_attention_n_queries: int = 32,
            atom_attention_n_keys: int = 128,
            atom_decoder_blocks: int = 3,
            atom_decoder_heads: int = 16,
            token_transformer_blocks: int = 24,
            token_transformer_heads: int = 16,
            sd_data: float = 16.0,
            s_max: float = 160.0,
            s_min: float = 4e-4,
            p: float = 7.0,
            clear_cache_between_blocks: bool = False,
            blocks_per_ckpt: int = 1,
    ):
        """Build the diffusion denoiser and its submodules.

        Four submodules are created, in the order activations flow through
        them:
          1. ``diffusion_conditioning`` - a ``DiffusionConditioning`` module.
          2. ``atom_attention_encoder`` - sequence-local atom attention that
             aggregates atoms to coarse-grained tokens.
          3. ``token_proj`` + ``diffusion_transformer`` +
             ``token_post_layer_norm`` - full self-attention at token level.
          4. ``atom_attention_decoder`` - broadcasts token activations back to
             atoms and runs sequence-local atom attention.

        Args:
            c_atom: Forwarded as ``c_atom`` to the atom encoder and decoder
                (by the AF3 naming convention, the per-atom channel width).
            c_atompair: Forwarded as ``c_atompair`` to the atom encoder and
                decoder (presumably the atom-pair channel width).
            c_token: Token channel width; used by the conditioning, both atom
                attention modules, the token projection and the transformer.
            c_tokenpair: Forwarded as the pair width to the conditioning, the
                transformer, and as ``c_trunk_pair`` to the atom encoder.
            atom_encoder_blocks: Number of blocks in the atom attention encoder.
            atom_encoder_heads: Number of heads in the atom attention encoder.
            dropout: Dropout rate forwarded to the encoder, the token
                transformer and the decoder.
            atom_attention_n_queries: ``n_queries`` for both atom attention
                modules.
            atom_attention_n_keys: ``n_keys`` for both atom attention modules.
            atom_decoder_blocks: Number of blocks in the atom attention decoder.
            atom_decoder_heads: Number of heads in the atom attention decoder.
            token_transformer_blocks: Number of blocks in the token transformer.
            token_transformer_heads: Number of heads in the token transformer.
            sd_data: Data standard deviation (sigma_data). Forwarded to the
                conditioning module and used by ``c_skip``/``c_out``/``c_in``.
            s_max, s_min, p: Stored on the module only; not used in this
                constructor. Presumably the noise-schedule bounds and exponent
                of Karras et al. (2022) used at sampling time - confirm.
            clear_cache_between_blocks: Forwarded to the atom encoder and the
                token transformer (not to the decoder).
            blocks_per_ckpt: Forwarded to the token transformer only
                (presumably its activation-checkpointing granularity).
        """
        super().__init__()
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_tokenpair = c_tokenpair
        self.atom_encoder_blocks = atom_encoder_blocks
        self.atom_encoder_heads = atom_encoder_heads
        # Stored alongside the encoder hyperparameters for symmetry; these
        # were previously accepted but never recorded on the module.
        self.atom_decoder_blocks = atom_decoder_blocks
        self.atom_decoder_heads = atom_decoder_heads
        self.dropout = dropout
        self.atom_attention_n_queries = atom_attention_n_queries
        self.atom_attention_n_keys = atom_attention_n_keys
        self.token_transformer_blocks = token_transformer_blocks
        self.token_transformer_heads = token_transformer_heads
        self.sd_data = sd_data
        self.s_max = s_max
        self.s_min = s_min
        self.p = p
        self.clear_cache_between_blocks = clear_cache_between_blocks
        self.blocks_per_ckpt = blocks_per_ckpt

        # Conditioning
        self.diffusion_conditioning = DiffusionConditioning(
            c_token=c_token,
            c_pair=c_tokenpair,
            sd_data=sd_data
        )

        # Sequence-local atom attention and aggregation to coarse-grained tokens
        self.atom_attention_encoder = AtomAttentionEncoder(
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_trunk_pair=c_tokenpair,
            # The encoder depth must come from atom_encoder_blocks; using the
            # decoder count here silently ignored the encoder argument.
            no_blocks=atom_encoder_blocks,
            no_heads=atom_encoder_heads,
            dropout=dropout,
            n_queries=atom_attention_n_queries,
            n_keys=atom_attention_n_keys,
            trunk_conditioning=True,
            clear_cache_between_blocks=clear_cache_between_blocks
        )

        # Full self-attention on token level
        self.token_proj = nn.Sequential(
            LayerNorm(c_token),
            LinearNoBias(c_token, c_token, init='final')
        )
        self.diffusion_transformer = DiffusionTransformer(
            c_token=c_token,
            c_pair=c_tokenpair,
            no_blocks=token_transformer_blocks,
            no_heads=token_transformer_heads,
            dropout=dropout,
            clear_cache_between_blocks=clear_cache_between_blocks,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.token_post_layer_norm = LayerNorm(c_token)

        # Broadcast token activations to atoms and run sequence-local atom attention
        self.atom_attention_decoder = AtomAttentionDecoder(
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            no_blocks=atom_decoder_blocks,
            no_heads=atom_decoder_heads,
            dropout=dropout,
            n_queries=atom_attention_n_queries,
            n_keys=atom_attention_n_keys
        )
    
    def c_skip(self, timesteps: Tensor) -> Tensor:
        """Computes the skip connection scaling factor from Karras et al. (2022)."""
        return self.sd_data ** 2 / (self.sd_data ** 2 + timesteps ** 2)
    
    def c_out(self, timesteps: Tensor) -> Tensor:
        """Computes the output scaling factor from Karras et al. (2022)."""
        return timesteps * self.sd_data / torch.sqrt(self.sd_data ** 2 + timesteps ** 2)
    
    def c_in(self, timesteps: Tensor) -> Tensor:
        """Computes the input scaling factor from Karras et al. (2022)."""
        return 1. / torch.sqrt(self.sd_data ** 2 + timesteps ** 2)

    def scale_inputs(
            self,
            noisy_atoms: Tensor,
            timesteps: Tensor
    ) -> Tensor:
        """Scales positions to dim

网站公告

今日签到

点亮在社区的每一天
去签到