|
import torch |
|
import torch.nn as nn |
|
from collections import Counter |
|
|
|
class BeastTokenizer:
    """Whitespace tokenizer with a frequency-ranked vocabulary.

    Index ``PAD_IDX`` (0) is reserved for ``'<PAD>'`` and ``UNK_IDX`` (1) for
    ``'<UNK>'``. The ``vocab_size - 2`` most frequent whitespace-separated
    words of the training texts get indices 2, 3, ... in descending
    frequency order (ties keep first-seen order, per ``Counter.most_common``).
    """

    PAD_IDX = 0
    UNK_IDX = 1

    def __init__(self, texts=None, vocab_size=5000):
        """Build the vocabulary.

        Args:
            texts: Iterable of strings to count words from (list, tuple,
                generator, NumPy array, pandas Series, ...). ``None`` or an
                empty iterable gives a vocabulary holding only the two
                special tokens.
            vocab_size: Maximum vocabulary size, including the two special
                tokens.
        """
        self.word2idx = {'<PAD>': self.PAD_IDX, '<UNK>': self.UNK_IDX}
        # Compare against None rather than testing truthiness: bool() raises
        # for array-likes (NumPy arrays, pandas Series), and a ``[]`` default
        # would be a single list object shared by every call.
        if texts is None:
            return

        counter = Counter(word for text in texts for word in text.split())
        # A corpus word spelled like a special token must not overwrite the
        # reserved indices (e.g. remap '<PAD>' away from 0).
        for special in self.word2idx:
            counter.pop(special, None)

        common = counter.most_common(max(vocab_size - 2, 0))
        self.word2idx.update(
            {word: idx for idx, (word, _) in enumerate(common, start=2)}
        )

    def encode(self, text, max_len=100):
        """Convert *text* into a fixed-length list of token ids.

        Words missing from the vocabulary map to ``UNK_IDX``. The id list is
        truncated to ``max_len`` entries or right-padded with ``PAD_IDX``, so
        for a non-negative ``max_len`` the result has exactly ``max_len``
        elements.
        """
        tokens = [self.word2idx.get(word, self.UNK_IDX) for word in text.split()]
        return tokens[:max_len] + [self.PAD_IDX] * (max_len - len(tokens))
|
|
|
class BeastSpamModel(nn.Module):
    """Binary text classifier: embedding -> Conv1d -> BiLSTM -> linear -> sigmoid.

    ``forward`` takes a batch of token ids shaped ``(batch, seq_len)`` and
    returns one probability per example, shaped ``(batch,)``.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=64,
                 conv_channels=128, kernel_size=5):
        """Create the layers.

        Args:
            vocab_size: Number of embedding rows; id 0 is the padding index
                (its embedding receives no gradient).
            embed_dim: Embedding width.
            hidden_dim: LSTM hidden size per direction.
            conv_channels: Conv1d output channels, which is also the LSTM
                input size.
            kernel_size: Conv1d kernel width. Padding is
                ``kernel_size // 2``, which keeps the sequence length
                unchanged for odd kernels.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.conv = nn.Conv1d(
            embed_dim, conv_channels,
            kernel_size=kernel_size, padding=kernel_size // 2,
        )
        self.lstm = nn.LSTM(conv_channels, hidden_dim,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch,)`` with sigmoid probabilities.
        """
        x = self.embedding(x)              # (batch, seq_len, embed_dim)
        x = self.conv(x.permute(0, 2, 1))  # Conv1d expects (batch, channels, seq_len)
        x = x.permute(0, 2, 1)             # back to batch-first for the LSTM
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, batch, hidden); its last two entries are the
        # last layer's forward and backward final states. Slicing
        # lstm_out[:, -1, :] would instead give the backward direction after a
        # single step, i.e. having seen only the final position (a PAD
        # position for right-padded input).
        hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        out = self.fc(hidden)
        return self.sigmoid(out).squeeze(1)
|
|