# SMS-spam-detection / model.py
# Author: abdullahalioo; commit f02a16d ("Upload 6 files", verified).
import torch
import torch.nn as nn
from collections import Counter
class BeastTokenizer:
    """Whitespace tokenizer with a frequency-ranked, fixed-size vocabulary.

    Index 0 is reserved for padding (``<PAD>``) and index 1 for words not in
    the vocabulary (``<UNK>``). The remaining ``vocab_size - 2`` ids are
    assigned to the most frequent whitespace-separated words of ``texts``,
    most frequent first, starting at id 2.
    """

    def __init__(self, texts=None, vocab_size=5000):
        """Build ``word2idx`` from ``texts``.

        Args:
            texts: Iterable of training strings, iterated once. ``None`` (the
                default) or an empty collection gives a vocabulary holding
                only ``<PAD>`` and ``<UNK>``. ``None`` replaces the former
                mutable ``[]`` default, so no list object is shared between
                calls.
            vocab_size: Total vocabulary size, including the two special
                tokens. Values of 2 or less add no words.
        """
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        if texts:
            counter = Counter(word for text in texts for word in text.split())
            common = counter.most_common(vocab_size - 2)
            # Ids 0 and 1 are taken by the special tokens, so words start at 2.
            self.word2idx.update(
                {word: idx for idx, (word, _) in enumerate(common, start=2)}
            )

    def encode(self, text, max_len=100):
        """Convert ``text`` to a list of token ids of length ``max_len``.

        Words are split on whitespace. Unknown words map to ``<UNK>`` (1).
        The id list is truncated to ``max_len`` entries and, when shorter,
        right-padded with ``<PAD>`` (0). The result has exactly ``max_len``
        entries for any non-negative ``max_len``.
        """
        ids = [self.word2idx.get(word, 1) for word in text.split()][:max_len]
        # Non-positive repeat counts yield an empty list, so no pad is added
        # once the sequence is already max_len long.
        return ids + [0] * (max_len - len(ids))
class BeastSpamModel(nn.Module):
    """Embedding -> Conv1d -> bidirectional LSTM -> Linear -> Sigmoid classifier.

    Maps a batch of token-id sequences to one probability per sequence,
    returned with shape ``(batch,)``.

    Args:
        vocab_size: Number of rows in the embedding table. Id 0 is the
            embedding's ``padding_idx``.
        embed_dim: Embedding width, which is also the conv's input channels.
        hidden_dim: LSTM hidden size per direction. The linear head receives
            ``2 * hidden_dim`` features.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=64):
        super().__init__()
        conv_channels = 128  # conv output width doubles as the LSTM input width
        # Submodule creation order determines state_dict key order and how a
        # seeded RNG is consumed during weight init -- keep it as is.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0,
        )
        # kernel_size=5 with padding=2 leaves the sequence length unchanged.
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=conv_channels,
            kernel_size=5,
            padding=2,
        )
        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(in_features=2 * hidden_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-sequence probabilities of shape ``(batch,)``.

        ``x`` holds integer token ids indexed as ``(batch, seq_len)``.
        """
        # Embeddings come out as (batch, seq, embed). Conv1d wants channels
        # first, and the batch_first LSTM wants (batch, seq, features) back.
        channels_first = self.embedding(x).transpose(1, 2)
        features = self.conv(channels_first).transpose(1, 2)
        sequence_out, _ = self.lstm(features)
        # NOTE(review): only the final time step is classified. The backward
        # half of that step has seen just one token. If inputs come from
        # BeastTokenizer.encode (right-padded), the forward half has also run
        # over the padding. Concatenating the final hidden states
        # (h_n[-2], h_n[-1]) is the usual alternative, but it would change
        # the outputs of any trained checkpoint -- confirm before switching.
        last_step = sequence_out[:, -1]
        logits = self.fc(last_step)
        return self.sigmoid(logits).squeeze(1)