AF3 Transition和ConditionedTransitionBlock类解读

发布于:2025-02-10 ⋅ 阅读:(32) ⋅ 点赞:(0)
AlphaFold3的Transition/ ConditionedTransitionBlock类
1. Transition 类
  • 作用

    • 提升模型的表达能力,通过扩展和收缩通道,学习不同层次的特征。
    • 作为残差块的一部分,帮助模型捕捉更复杂的序列-结构映射关系。
  • 生物学意义

    • 帮助捕捉蛋白质序列中局部和全局特征,为后续模块提供更丰富的特征。
2. ConditionedTransitionBlock 类
  • 作用

    • 通过条件张量 s 自适应调整输入特征 a 的分布。
    • 门控机制控制特征的更新量,避免过拟合,增强模型的条件依赖能力。
  • 生物学意义

    • 模拟蛋白质在不同环境或上下文(如特定配体、化学环境)下的行为。
    • 允许模型根据条件信息动态调整特征表示,捕捉更细粒度的结构变化。

源代码:

"""Transition blocks in AlphaFold3"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from src.models.components.primitives import AdaLN
from src.models.components.primitives import Linear, LinearNoBias


class Transition(nn.Module):
    """A transition block for a residual update."""
    def __init__(self, input_dim: int, n: int = 4):
        """
        Args:
            input_dim:
                Channels of the input tensor
            n:
                channel expansion factor for hidden dimensions
        """
        super(Transition, self).__init__()
        self.layer_norm = LayerNorm(input_dim)
        self.linear_1 = LinearNoBias(input_dim, n * input_dim, init='relu')
        self.linear_2 = LinearNoBias(input_dim, n * input_dim, init='default')
        self.output_linear = LinearNoBias(input_dim * n, input_dim, init='final')

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.silu(self.linear_1(x)) * self.linear_2(x)
        return self.output_linear(x)


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition block with adaptive layer norm."""
    def __init__(self,
                 input_dim: int,
                 n: int = 2):
        """
        Args:
            input_dim:
                Channels of the input tensor
            n:
                channel expansion factor for hidden dimensions
        """
        super(ConditionedTransitionBlock, self).__init__()
        self.ada_ln = AdaLN(input_dim)
        self.hidden_gating_linear = LinearNoBias(input_dim, n * input_dim, init='relu')
        self.hidden_linear = LinearNoBias(input_dim, n * input_dim, init='default')
        self.output_linear = L

网站公告

今日签到

点亮在社区的每一天
去签到