Parker Tope committed · Commit fe33e43
1 Parent(s): 149e6df

adding model

Files changed: models/model.py (+92, -0)
models/model.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import math

class RNAStructurePredictor(nn.Module):
    def __init__(self, d_model=128, nhead=8, num_encoder_layers=3, dim_feedforward=512, dropout=0.01,
                 num_structures=5, desc_dim=768):
        super().__init__()

        self.d_model = d_model
        self.num_structures = num_structures

        # ── Embeddings ─────────────────────────────────────────────
        self.seq_embedding = nn.Embedding(5, d_model)  # A,C,G,U,pad
        self.torsion_projection = nn.Linear(9, d_model)
        self.description_projection = nn.Linear(desc_dim, d_model)

        # ── BPPM CNN encoder ──────────────────────────────────────
        self.bppm_conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, d_model, 3, padding=1), nn.ReLU(),
        )

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        # Transformer encoder
        enc_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(enc_layer,
                                                         num_encoder_layers)

        # ── Structure decoders (predict 3-D coords) ───────────────
        self.structure_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(dim_feedforward, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(dim_feedforward, 3)
            ) for _ in range(num_structures)
        ])

    # ───────────────────────────────────────────────────────────────
    def forward(self, seq, bppm, torsion, description_emb, mask=None):
        """
        mask: Bool tensor [B, L] where 1 = real token, 0 = padding.
        If None, no masking is applied inside the encoder.
        """
        B, L = seq.shape

        # Embeddings
        seq_embed = self.seq_embedding(seq)                        # [B,L,D]
        tors_embed = self.torsion_projection(torsion)              # [B,L,D]

        bppm_feat = self.bppm_conv(bppm.unsqueeze(1))              # [B,D,L,L]
        bppm_feat = bppm_feat.mean(dim=2).transpose(1, 2)          # [B,L,D]

        desc_proj = self.description_projection(description_emb)  # [B,D]
        desc_proj = desc_proj.unsqueeze(1).expand(-1, L, -1)       # [B,L,D]

        combined = seq_embed + tors_embed + bppm_feat + desc_proj
        combined = self.positional_encoding(combined)

        # Transformer (apply mask if provided)
        if mask is not None:
            encoded = self.transformer_encoder(
                combined,
                src_key_padding_mask=~mask.bool()  # True = pad for the PyTorch API
            )
        else:
            encoded = self.transformer_encoder(combined)

        # Decode multiple structure hypotheses
        out = [dec(encoded) for dec in self.structure_decoders]  # n*[B,L,3]
        return torch.stack(out, dim=1)                           # [B,5,L,3]


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2], pe[:, 1::2] = torch.sin(pos * div), torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1,max_len,D]

    def forward(self, x):  # x: [B,L,D]
        return self.dropout(x + self.pe[:, :x.size(1)])
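
For reference, a minimal smoke test of the module added above; the batch size, sequence length, random inputs, and the models.model import path are illustrative assumptions, not part of this commit:

import torch
from models.model import RNAStructurePredictor  # assumed import path for this repo layout

B, L = 2, 40                                # hypothetical batch size and sequence length
model = RNAStructurePredictor().eval()

seq = torch.randint(0, 4, (B, L))           # A,C,G,U token ids (index 4 reserved for pad)
bppm = torch.rand(B, L, L)                  # base-pair probability matrix
torsion = torch.randn(B, L, 9)              # 9 torsion-angle features per residue
desc = torch.randn(B, 768)                  # pooled description embedding (desc_dim=768)
mask = torch.ones(B, L, dtype=torch.bool)   # 1 = real token, 0 = padding

with torch.no_grad():
    coords = model(seq, bppm, torsion, desc, mask=mask)
print(coords.shape)                         # torch.Size([2, 5, 40, 3])

The dimension of size 5 in the output corresponds to the num_structures decoder heads, each producing per-residue 3-D coordinates.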