DiffusionModule 类是 AlphaFold3 中的一个关键组件，它基于扩散模型（diffusion models）来进行蛋白质结构预测。
源代码:
"""
Diffusion Module from AlphaFold3.
The StructureModule of AlphaFold2 using invariant point attention was replaced with a relatively standard
non-equivariant point-cloud diffusion model over all atoms. The denoiser is based on a modern transformer,
but with several modifications to make it more amenable to the task. The main changes are:
- Conditioning from the trunk in several ways: initialise the activations for the single embedding, use
a variant of Adaptive Layernorm for the single conditioning and logit biasing for the pair conditioning.
- Standard transformer tricks (e.g. SwiGLU) and methods used in AlphaFold2 (gating)
- A two-level architecture, working first on atoms, then tokens, then atoms again.
"""
import math
import torch
from torch import nn
from torch import Tensor
from torch.nn import LayerNorm
from typing import Dict, Tuple
from src.models.diffusion_conditioning import DiffusionConditioning
from src.models.diffusion_transformer import DiffusionTransformer
from src.models.components.atom_attention import AtomAttentionEncoder, AtomAttentionDecoder
from src.models.components.primitives import LinearNoBias
from src.utils.geometry.vector import Vec3Array
from src.diffusion.augmentation import centre_random_augmentation
from src.diffusion.noise import sample_noise_level, noise_positions
class DiffusionModule(torch.nn.Module):
def __init__(
    self,
    c_atom: int = 128,
    c_atompair: int = 16,
    c_token: int = 768,
    c_tokenpair: int = 128,
    atom_encoder_blocks: int = 3,
    atom_encoder_heads: int = 16,
    dropout: float = 0.0,
    atom_attention_n_queries: int = 32,
    atom_attention_n_keys: int = 128,
    atom_decoder_blocks: int = 3,
    atom_decoder_heads: int = 16,
    token_transformer_blocks: int = 24,
    token_transformer_heads: int = 16,
    sd_data: float = 16.0,
    s_max: float = 160.0,
    s_min: float = 4e-4,
    p: float = 7.0,
    clear_cache_between_blocks: bool = False,
    blocks_per_ckpt: int = 1,
):
    """Build the conditioning, atom encoder, token transformer and atom decoder.

    Every argument is stored on the instance under the same name.

    Args:
        c_atom: forwarded as ``c_atom`` to the atom attention encoder and decoder.
        c_atompair: forwarded as ``c_atompair`` to the atom attention encoder and decoder.
        c_token: token channel size; used for the conditioning, the token projection,
            the token transformer, the post layer norm and both atom attention modules.
        c_tokenpair: forwarded as ``c_pair`` to the conditioning and the token
            transformer, and as ``c_trunk_pair`` to the atom attention encoder.
        atom_encoder_blocks: ``no_blocks`` of the atom attention encoder.
        atom_encoder_heads: ``no_heads`` of the atom attention encoder.
        dropout: forwarded to the encoder, the token transformer and the decoder.
        atom_attention_n_queries: ``n_queries`` of the atom attention encoder and decoder.
        atom_attention_n_keys: ``n_keys`` of the atom attention encoder and decoder.
        atom_decoder_blocks: ``no_blocks`` of the atom attention decoder.
        atom_decoder_heads: ``no_heads`` of the atom attention decoder.
        token_transformer_blocks: ``no_blocks`` of the token-level transformer.
        token_transformer_heads: ``no_heads`` of the token-level transformer.
        sd_data: forwarded to the conditioning module and used by the
            ``c_skip`` / ``c_out`` / ``c_in`` scaling factors.
        s_max: only stored here; presumably the maximum noise level of the
            schedule -- TODO confirm against the sampling code.
        s_min: only stored here; presumably the minimum noise level of the
            schedule -- TODO confirm against the sampling code.
        p: only stored here; presumably the noise-schedule exponent -- TODO confirm.
        clear_cache_between_blocks: forwarded to the atom attention encoder and the
            token transformer.
        blocks_per_ckpt: forwarded to the token transformer.
    """
    super(DiffusionModule, self).__init__()
    self.c_atom = c_atom
    self.c_atompair = c_atompair
    self.c_token = c_token
    self.c_tokenpair = c_tokenpair
    self.atom_encoder_blocks = atom_encoder_blocks
    self.atom_encoder_heads = atom_encoder_heads
    self.dropout = dropout
    self.atom_attention_n_queries = atom_attention_n_queries
    self.atom_attention_n_keys = atom_attention_n_keys
    # Kept alongside the encoder settings so every hyper-parameter is inspectable.
    self.atom_decoder_blocks = atom_decoder_blocks
    self.atom_decoder_heads = atom_decoder_heads
    self.token_transformer_blocks = token_transformer_blocks
    self.token_transformer_heads = token_transformer_heads
    self.sd_data = sd_data
    self.s_max = s_max
    self.s_min = s_min
    self.p = p
    self.clear_cache_between_blocks = clear_cache_between_blocks
    self.blocks_per_ckpt = blocks_per_ckpt

    # Conditioning
    self.diffusion_conditioning = DiffusionConditioning(
        c_token=c_token,
        c_pair=c_tokenpair,
        sd_data=sd_data,
    )

    # Sequence-local atom attention and aggregation to coarse-grained tokens
    self.atom_attention_encoder = AtomAttentionEncoder(
        c_token=c_token,
        c_atom=c_atom,
        c_atompair=c_atompair,
        c_trunk_pair=c_tokenpair,
        # The encoder depth comes from its own parameter; the decoder's block
        # count must not leak in here (both default to 3).
        no_blocks=atom_encoder_blocks,
        no_heads=atom_encoder_heads,
        dropout=dropout,
        n_queries=atom_attention_n_queries,
        n_keys=atom_attention_n_keys,
        trunk_conditioning=True,
        clear_cache_between_blocks=clear_cache_between_blocks,
    )

    # Full self-attention on token level
    self.token_proj = nn.Sequential(
        LayerNorm(c_token),
        LinearNoBias(c_token, c_token, init='final'),
    )
    self.diffusion_transformer = DiffusionTransformer(
        c_token=c_token,
        c_pair=c_tokenpair,
        no_blocks=token_transformer_blocks,
        no_heads=token_transformer_heads,
        dropout=dropout,
        clear_cache_between_blocks=clear_cache_between_blocks,
        blocks_per_ckpt=blocks_per_ckpt,
    )
    self.token_post_layer_norm = LayerNorm(c_token)

    # Broadcast token activations to atoms and run sequence-local atom attention
    self.atom_attention_decoder = AtomAttentionDecoder(
        c_token=c_token,
        c_atom=c_atom,
        c_atompair=c_atompair,
        no_blocks=atom_decoder_blocks,
        no_heads=atom_decoder_heads,
        dropout=dropout,
        n_queries=atom_attention_n_queries,
        n_keys=atom_attention_n_keys,
    )
def c_skip(self, timesteps: Tensor) -> Tensor:
    """Skip-connection weight of the Karras et al. (2022) preconditioning.

    Evaluates ``sd_data^2 / (sd_data^2 + timesteps^2)`` element-wise.

    Args:
        timesteps: noise levels to evaluate the factor at.

    Returns:
        The scaling factor, one value per entry of ``timesteps``.
    """
    data_variance = self.sd_data ** 2
    total_variance = data_variance + timesteps ** 2
    return data_variance / total_variance
def c_out(self, timesteps: Tensor) -> Tensor:
    """Output weight of the Karras et al. (2022) preconditioning.

    Evaluates ``timesteps * sd_data / sqrt(sd_data^2 + timesteps^2)`` element-wise.

    Args:
        timesteps: noise levels to evaluate the factor at.

    Returns:
        The scaling factor, one value per entry of ``timesteps``.
    """
    total_variance = self.sd_data ** 2 + timesteps ** 2
    numerator = timesteps * self.sd_data
    return numerator / torch.sqrt(total_variance)
def c_in(self, timesteps: Tensor) -> Tensor:
    """Input weight of the Karras et al. (2022) preconditioning.

    Evaluates ``1 / sqrt(sd_data^2 + timesteps^2)`` element-wise.

    Args:
        timesteps: noise levels to evaluate the factor at.

    Returns:
        The scaling factor, one value per entry of ``timesteps``.
    """
    total_variance = self.sd_data ** 2 + timesteps ** 2
    total_std = torch.sqrt(total_variance)
    return 1. / total_std
def scale_inputs(
self,
noisy_atoms: Tensor,
timesteps: Tensor
) -> Tensor:
"""Scales positions to dim