【Closure-Hayd】

发布于:2025-05-17 ⋅ 阅读:(16) ⋅ 点赞:(0)

RNA序列本身存在结构上的物理信息,因此可以利用文献提供的相关方法来对RNA序列的物理特征进行更加细致的提取。

  • 几何向量编码(GVP模块)​借鉴Rhodesign模型中的GVP(Geometric Vector Perceptron)模块,将每个核苷酸的原子坐标分解为标量特征(如原子间距离、二面角)​和矢量特征(如C4'-C4'链方向向量)​。例如:

    • 标量特征:计算磷酸骨架(P-O5'-C5'-C4'-C3'-O3')的二面角、键长等几何参数

    • 矢量特征:提取相邻核苷酸C4'原子的空间向量,编码局部骨架方向。

    • 侧链特征:对N1/N9原子与骨架的几何关系进行编码,区分嘧啶和嘌呤碱基。

  • 缺失值处理对NaN填充的原子坐标,采用掩码机制​(masked attention)或插值补全​(基于已知原子的空间分布预测缺失坐标),避免噪声干扰。

Rhodesign模型的github链接:https://github.com/ml4bio/RhoDesign

模型的文章链接:https://www.nature.com/articles/s43588-024-00720-6

 RDesign/model/module.py at master · A4Bio/RDesign

git clone https://github.com/A4Bio/RDesign.git

 eval "$(/mnt/workspace/miniconda3/bin/conda shell.bash hook)"

cd RDesign
conda env create -f environment.yml
conda activate RDesign
class TransformerLayer(nn.Module):
    def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.0):
        super(TransformerLayer, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([nn.BatchNorm1d(num_hidden) for _ in range(2)])
        self.attention = NeighborAttention(num_hidden, num_hidden + num_in, num_heads)
        self.dense = nn.Sequential(
            nn.Linear(num_hidden, num_hidden*4),
            nn.ReLU(),
            nn.Linear(num_hidden*4, num_hidden)
        )

    def forward(self, h_V, h_E, edge_idx, batch_id=None):
        center_id = edge_idx[0]
        dh = self.attention(h_V, h_E, center_id, batch_id)
        h_V = self.norm[0](h_V + self.dropout(dh))
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))
        return h_V
class NeighborAttention(nn.Module):
    def __init__(self, num_hidden, num_in, num_heads=4):
        super(NeighborAttention, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.Bias = nn.Sequential(
                                nn.Linear(num_hidden*3, num_hidden),
                                nn.ReLU(),
                                nn.Linear(num_hidden,num_hidden),
                                nn.ReLU(),
                                nn.Linear(num_hidden,num_heads)
                                )
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def forward(self, h_V, h_E, center_id, batch_id):
        N = h_V.shape[0]
        E = h_E.shape[0]
        n_heads = self.num_heads
        d = int(self.num_hidden / n_heads)

        Q = self.W_Q(h_V).view(N, n_heads, 1, d)[center_id]
        K = self.W_K(h_E).view(E, n_heads, d, 1)
        attend_logits = torch.matmul(Q, K).view(E, n_heads, 1)
        attend_logits = attend_logits / np.sqrt(d)

        V = self.W_V(h_E).view(-1, n_heads, d) 
        attend = scatter_softmax(attend_logits, index=center_id, dim=0)
        h_V = scatter_sum(attend*V, center_id, dim=0).view([N, self.num_hidden])
        h_V_update = self.W_O(h_V)
        return h_V_update


网站公告

今日签到

点亮在社区的每一天
去签到