import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=256,
                 n_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()
        # Embedding layer; index 0 is reserved for padding tokens.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers,
                            batch_first=True, dropout=dropout,
                            bidirectional=bidirectional)
        self.bi = 2 if bidirectional else 1
        # Single logit for binary classification.
        self.fc = nn.Linear(hidden_dim * self.bi, 1)

    def forward(self, x, lengths):
        x = self.embedding(x)
        # Pack the padded batch so the LSTM skips padding positions.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h, _) = self.lstm(packed)
        if self.bi == 2:
            # Concatenate the final forward and backward hidden states.
            h = torch.cat((h[-2], h[-1]), dim=1)
        else:
            h = h[-1]
        return self.fc(h).squeeze(1)
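

# Usage sketch (assumed values, not from the original snippet): a batch of two
# zero-padded token-id sequences of different lengths. vocab_size=1000 and the
# token ids are made up for illustration; `lengths` must hold the true
# (unpadded) lengths so pack_padded_sequence ignores the padding positions.
if __name__ == "__main__":
    model = LSTMClassifier(vocab_size=1000)
    batch = torch.tensor([[5, 17, 42, 3, 0, 0],     # length 4, padded with 0
                          [8, 21, 9, 14, 2, 11]])   # length 6
    lengths = torch.tensor([4, 6])
    logits = model(batch, lengths)                  # shape: (batch_size,)
    probs = torch.sigmoid(logits)                   # binary class probabilities
    print(logits.shape, probs)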