AlehsanAliyev committed on
Commit
32cc8bc
·
verified ·
1 Parent(s): 5394de5

updating model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -15
model.py CHANGED
@@ -1,15 +1,36 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM sequence classifier.

    Token ids are embedded, fed through a bidirectional LSTM, and the final
    hidden states of the forward and backward directions are concatenated
    and projected to ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each token embedding.
        hidden_dim: Hidden size of the LSTM, per direction.
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Both directions are concatenated, hence the doubled input width.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x, lengths=None):
        """Return classification logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, batch-first.
            lengths: Optional true (unpadded) length of every sequence, as a
                list or 1-D tensor with one positive entry per batch row.
                Padding is expected at the end of each row. When given, the
                batch is packed so the LSTM stops at each sequence's real
                end instead of running over padding. When omitted, every
                position is processed (the previous behaviour).

        Returns:
            Tensor of logits with one row per batch element.
        """
        embedded = self.embedding(x)
        if lengths is not None:
            # pack_padded_sequence requires the lengths on the CPU.
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(embedded)
        # h_n[0] is the forward direction's last state, h_n[1] the backward one.
        h_cat = torch.cat((h_n[0], h_n[1]), dim=1)
        return self.fc(h_cat)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import torch.nn as nn
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
class BiLSTMConfig(PretrainedConfig):
    """Configuration for the BiLSTM classifier.

    Holds the vocabulary size, embedding width, per-direction LSTM hidden
    size and number of output labels. Any further keyword arguments are
    forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "bilstm"

    def __init__(
        self,
        vocab_size=64000,
        embedding_dim=1024,
        hidden_dim=512,
        num_labels=3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        # NOTE(review): PretrainedConfig appears to derive num_labels from
        # id2label; confirm that a config saved with a non-default label
        # count still reports that count after being reloaded.
        self.num_labels = num_labels
15
+
16
class BiLSTMClassifier(PreTrainedModel):
    """Bidirectional LSTM sequence classifier in a ``PreTrainedModel`` shell.

    Token ids are embedded, fed through a bidirectional LSTM, and the final
    hidden states of both directions are concatenated and projected to
    ``config.num_labels`` logits.
    """

    config_class = BiLSTMConfig

    def __init__(self, config: BiLSTMConfig):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.lstm = nn.LSTM(
            config.embedding_dim,
            config.hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence the doubled input width.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_labels)

        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute logits and, when labels are supplied, the cross-entropy loss.

        Args:
            input_ids: Integer tensor of token ids, batch-first.
            attention_mask: Optional tensor shaped like ``input_ids``;
                non-zero entries mark real tokens. When given, masked
                positions are excluded from the LSTM so they cannot leak
                into the final hidden states. When omitted, every position
                is processed.
            labels: Optional tensor of target class indices.

        Returns:
            ``{"logits": ...}``, plus a ``"loss"`` entry when ``labels`` is
            provided.
        """
        embedded = self.embedding(input_ids)

        if attention_mask is not None:
            mask = (attention_mask != 0).to(torch.long)
            # A stable descending sort moves attended positions to the front
            # of each row while keeping their relative order, so left- and
            # right-padded batches are handled alike.
            _, order = torch.sort(mask, dim=1, descending=True, stable=True)
            embedded = embedded.gather(1, order.unsqueeze(-1).expand_as(embedded))
            # Packing rejects zero-length rows; a fully masked row is
            # processed as a single step instead of raising. Lengths must
            # live on the CPU.
            lengths = mask.sum(dim=1).clamp(min=1).cpu()
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )

        _, (h_n, _) = self.lstm(embedded)
        # h_n[0] is the forward direction's last state, h_n[1] the backward one.
        h_cat = torch.cat((h_n[0], h_n[1]), dim=1)
        logits = self.fc(h_cat)

        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}