import torch
import torch.nn as nn
import math


class RNAStructurePredictor(nn.Module):
    """Transformer-based predictor of multiple 3-D RNA structure hypotheses.

    Per-residue features (sequence tokens, torsion angles, a base-pair
    probability matrix) and a global description embedding are projected to a
    shared ``d_model`` space, summed, encoded with a Transformer, and decoded
    by ``num_structures`` independent MLP heads into 3-D coordinates.
    """

    def __init__(self, d_model=128, nhead=8, num_encoder_layers=3,
                 dim_feedforward=512, dropout=0.01, num_structures=5,
                 desc_dim=768):
        super().__init__()
        self.d_model = d_model
        self.num_structures = num_structures

        # ── Embeddings ─────────────────────────────────────────────
        self.seq_embedding = nn.Embedding(5, d_model)  # A,C,G,U,pad
        self.torsion_projection = nn.Linear(9, d_model)
        self.description_projection = nn.Linear(desc_dim, d_model)

        # ── BPPM CNN encoder ───────────────────────────────────────
        self.bppm_conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, d_model, 3, padding=1), nn.ReLU(),
        )

        # Positional encoding (adds position info before the Transformer)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        # Transformer encoder (batch_first: inputs are [B, L, D])
        enc_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(enc_layer, num_encoder_layers)

        # ── Structure decoders (one MLP head per 3-D hypothesis) ───
        self.structure_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, dim_feedforward),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(dim_feedforward, dim_feedforward),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(dim_feedforward, 3),
            )
            for _ in range(num_structures)
        ])

    # ───────────────────────────────────────────────────────────────
    def forward(self, seq, bppm, torsion, description_emb, mask=None):
        """Predict ``num_structures`` coordinate sets for a batch.

        Args:
            seq: Long tensor [B, L] of residue token ids in 0..4.
            bppm: Float tensor [B, L, L] base-pair probability matrix.
            torsion: Float tensor [B, L, 9] torsion-angle features.
            description_emb: Float tensor [B, desc_dim] global description.
            mask: Bool tensor [B, L] where 1 = real token, 0 = padding.
                If None, no masking is applied inside the encoder.

        Returns:
            Float tensor [B, num_structures, L, 3] of predicted coordinates.
        """
        B, L = seq.shape

        # Embeddings — all projected to [B, L, D] and summed.
        seq_embed = self.seq_embedding(seq)                       # [B,L,D]
        tors_embed = self.torsion_projection(torsion)             # [B,L,D]
        bppm_feat = self.bppm_conv(bppm.unsqueeze(1))             # [B,D,L,L]
        bppm_feat = bppm_feat.mean(dim=2).transpose(1, 2)         # [B,L,D]
        desc_proj = self.description_projection(description_emb)  # [B,D]
        desc_proj = desc_proj.unsqueeze(1).expand(-1, L, -1)      # [B,L,D]

        combined = seq_embed + tors_embed + bppm_feat + desc_proj
        combined = self.positional_encoding(combined)

        # Transformer (apply mask if provided). PyTorch's
        # src_key_padding_mask uses True = padding, hence the inversion.
        if mask is not None:
            encoded = self.transformer_encoder(
                combined,
                src_key_padding_mask=~mask.bool()
            )
        else:
            encoded = self.transformer_encoder(combined)

        # Decode multiple structure hypotheses: num_structures × [B,L,3]
        out = [dec(encoded) for dec in self.structure_decoders]
        return torch.stack(out, dim=1)  # [B, num_structures, L, 3]


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al.) with dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # Fix: for odd d_model the cos half has one fewer column than `div`
        # produces; slicing keeps shapes consistent (no-op for even d_model).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, D]

    def forward(self, x):  # x: [B, L, D]
        return self.dropout(x + self.pe[:, :x.size(1)])