RNA序列本身存在结构上的物理信息,因此可以利用文献提供的相关方法来对RNA序列的物理特征进行更加细致的提取。
几何向量编码(GVP模块)借鉴Rhodesign模型中的GVP(Geometric Vector Perceptron)模块,将每个核苷酸的原子坐标分解为标量特征(如原子间距离、二面角)和矢量特征(如C4'-C4'链方向向量)。例如:
标量特征:计算磷酸骨架(P-O5'-C5'-C4'-C3'-O3')的二面角、键长等几何参数
矢量特征:提取相邻核苷酸C4'原子的空间向量,编码局部骨架方向。
侧链特征:对N1/N9原子与骨架的几何关系进行编码,区分嘧啶和嘌呤碱基。
缺失值处理对NaN填充的原子坐标,采用掩码机制(masked attention)或插值补全(基于已知原子的空间分布预测缺失坐标),避免噪声干扰。
Rhodesign模型的github链接:https://github.com/ml4bio/RhoDesign
模型的文章链接:https://www.nature.com/articles/s43588-024-00720-6
RDesign/model/module.py at master · A4Bio/RDesign
git clone https://github.com/A4Bio/RDesign.git
eval "$(/mnt/workspace/miniconda3/bin/conda shell.bash hook)"
cd RDesign
conda env create -f environment.yml
conda activate RDesign
class TransformerLayer(nn.Module):
def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.0):
super(TransformerLayer, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
self.num_in = num_in
self.dropout = nn.Dropout(dropout)
self.norm = nn.ModuleList([nn.BatchNorm1d(num_hidden) for _ in range(2)])
self.attention = NeighborAttention(num_hidden, num_hidden + num_in, num_heads)
self.dense = nn.Sequential(
nn.Linear(num_hidden, num_hidden*4),
nn.ReLU(),
nn.Linear(num_hidden*4, num_hidden)
)
def forward(self, h_V, h_E, edge_idx, batch_id=None):
center_id = edge_idx[0]
dh = self.attention(h_V, h_E, center_id, batch_id)
h_V = self.norm[0](h_V + self.dropout(dh))
dh = self.dense(h_V)
h_V = self.norm[1](h_V + self.dropout(dh))
return h_V
class NeighborAttention(nn.Module):
def __init__(self, num_hidden, num_in, num_heads=4):
super(NeighborAttention, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
self.W_K = nn.Linear(num_in, num_hidden, bias=False)
self.W_V = nn.Linear(num_in, num_hidden, bias=False)
self.Bias = nn.Sequential(
nn.Linear(num_hidden*3, num_hidden),
nn.ReLU(),
nn.Linear(num_hidden,num_hidden),
nn.ReLU(),
nn.Linear(num_hidden,num_heads)
)
self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
def forward(self, h_V, h_E, center_id, batch_id):
N = h_V.shape[0]
E = h_E.shape[0]
n_heads = self.num_heads
d = int(self.num_hidden / n_heads)
Q = self.W_Q(h_V).view(N, n_heads, 1, d)[center_id]
K = self.W_K(h_E).view(E, n_heads, d, 1)
attend_logits = torch.matmul(Q, K).view(E, n_heads, 1)
attend_logits = attend_logits / np.sqrt(d)
V = self.W_V(h_E).view(-1, n_heads, d)
attend = scatter_softmax(attend_logits, index=center_id, dim=0)
h_V = scatter_sum(attend*V, center_id, dim=0).view([N, self.num_hidden])
h_V_update = self.W_O(h_V)
return h_V_update