Parker Tope committed on
Commit
fe33e43
·
1 Parent(s): 149e6df

adding model

Files changed (1)
  1. models/model.py +92 -0
models/model.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class RNAStructurePredictor(nn.Module):
+     def __init__(self, d_model=128, nhead=8, num_encoder_layers=3, dim_feedforward=512, dropout=0.01,
+                  num_structures=5, desc_dim=768):
+         super().__init__()
+
+         self.d_model = d_model
+         self.num_structures = num_structures
+
+         # ── Embeddings ─────────────────────────────────────────────
+         self.seq_embedding = nn.Embedding(5, d_model)               # A, C, G, U, pad
+         self.torsion_projection = nn.Linear(9, d_model)
+         self.description_projection = nn.Linear(desc_dim, d_model)
+
+         # ── BPPM CNN encoder ──────────────────────────────────────
+         self.bppm_conv = nn.Sequential(
+             nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
+             nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
+             nn.Conv2d(64, d_model, 3, padding=1), nn.ReLU(),
+         )
+
+         # Positional encoding
+         self.positional_encoding = PositionalEncoding(d_model, dropout)
+
+         # Transformer encoder
+         enc_layer = nn.TransformerEncoderLayer(
+             d_model, nhead, dim_feedforward, dropout, batch_first=True
+         )
+         self.transformer_encoder = nn.TransformerEncoder(enc_layer,
+                                                          num_encoder_layers)
+
+         # ── Structure decoders (predict 3-D coords) ───────────────
+         self.structure_decoders = nn.ModuleList([
+             nn.Sequential(
+                 nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
+                 nn.Linear(dim_feedforward, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
+                 nn.Linear(dim_feedforward, 3)
+             ) for _ in range(num_structures)
+         ])
+
+     # ───────────────────────────────────────────────────────────────
+     def forward(self, seq, bppm, torsion, description_emb, mask=None):
+         """
+         mask: Bool tensor [B, L] where 1 = real token, 0 = padding.
+               If None, no masking is applied inside the encoder.
+         """
+         B, L = seq.shape
+
+         # Embeddings
+         seq_embed = self.seq_embedding(seq)                         # [B,L,D]
+         tors_embed = self.torsion_projection(torsion)               # [B,L,D]
+
+         bppm_feat = self.bppm_conv(bppm.unsqueeze(1))               # [B,D,L,L]
+         bppm_feat = bppm_feat.mean(dim=2).transpose(1, 2)           # [B,L,D]
+
+         desc_proj = self.description_projection(description_emb)   # [B,D]
+         desc_proj = desc_proj.unsqueeze(1).expand(-1, L, -1)        # [B,L,D]
+
+         combined = seq_embed + tors_embed + bppm_feat + desc_proj
+         combined = self.positional_encoding(combined)
+
+         # Transformer (apply mask if provided)
+         if mask is not None:
+             encoded = self.transformer_encoder(
+                 combined,
+                 src_key_padding_mask=~mask.bool()                   # True = pad for PyTorch API
+             )
+         else:
+             encoded = self.transformer_encoder(combined)
+
+         # Decode multiple structure hypotheses
+         out = [dec(encoded) for dec in self.structure_decoders]     # n*[B,L,3]
+         return torch.stack(out, dim=1)                              # [B,5,L,3]
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
+         div = torch.exp(torch.arange(0, d_model, 2).float()
+                         * (-math.log(10000.0) / d_model))
+         pe[:, 0::2], pe[:, 1::2] = torch.sin(pos * div), torch.cos(pos * div)
+         self.register_buffer('pe', pe.unsqueeze(0))                 # [1,max_len,D]
+
+     def forward(self, x):                                           # x: [B,L,D]
+         return self.dropout(x + self.pe[:, :x.size(1)])
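
For quick reference, below is a minimal smoke-test sketch, not part of the commit, showing the input shapes the model expects. The batch size, sequence length, and import path are assumptions based on the file layout and the defaults above.

import torch
from models.model import RNAStructurePredictor    # assumes the repo root is on PYTHONPATH

model = RNAStructurePredictor()                    # d_model=128, num_structures=5, desc_dim=768
B, L = 2, 40                                       # hypothetical batch size and sequence length

seq = torch.randint(0, 5, (B, L))                  # nucleotide token ids (A, C, G, U, pad)
bppm = torch.rand(B, L, L)                         # base-pair probability matrix
torsion = torch.rand(B, L, 9)                      # 9 torsion-angle features per residue
description_emb = torch.rand(B, 768)               # precomputed description embedding
mask = torch.ones(B, L, dtype=torch.bool)          # 1 = real token, 0 = padding

coords = model(seq, bppm, torsion, description_emb, mask)
print(coords.shape)                                # torch.Size([2, 5, 40, 3])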