|
import torch |
|
import re |
|
from model import BeastTokenizer, BeastSpamModel |
|
from safetensors.torch import load_file |
|
|
|
def predict_spam(text, tokenizer, model):
    """Classify *text* as spam or not spam with a trained model.

    Args:
        text: Raw email text (any casing, may contain URLs/punctuation).
        tokenizer: Object exposing ``encode(str) -> list[int]``.
        model: Callable taking a ``(1, seq_len)`` long tensor and returning
            a scalar tensor score — assumed a sigmoid probability in
            [0, 1]; TODO confirm against BeastSpamModel.

    Returns:
        A human-readable verdict string.
    """
    # Normalization pipeline, innermost-first in the original one-liner:
    # lowercase -> strip URLs -> replace non-word chars -> collapse spaces.
    no_urls = re.sub(r"http\S+", "", text.lower())
    words_only = re.sub(r"\W", " ", no_urls)
    cleaned = re.sub(r"\s+", " ", words_only).strip()

    encoded = tokenizer.encode(cleaned)
    tensor = torch.tensor([encoded], dtype=torch.long)

    # Inference only — disable gradient tracking.
    with torch.no_grad():
        output = model(tensor).item()

    # BUG FIX: the original return statement's string literal was broken
    # across two physical lines (a SyntaxError). Rejoined on one line; the
    # mojibake bytes ("π₯"/"β" — likely garbled emoji) are kept as-is.
    return "π₯ It is SPAM!" if output > 0.5 else "β It is NOT spam."
|
|
|
if __name__ == "__main__":
    print("π© Enter the full email content below (press Enter twice to finish):\n")

    # Collect lines until the user submits a blank line. BUG FIX: also stop
    # cleanly on EOF (Ctrl-D, or exhausted piped stdin) instead of crashing
    # with an unhandled EOFError.
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            break
        if line.strip() == "":
            break
        lines.append(line)
    email = "\n".join(lines)

    # NOTE(review): the tokenizer is rebuilt here from a one-sentence dummy
    # corpus, so its vocabulary almost certainly differs from the one the
    # saved weights were trained against — predictions will be meaningless
    # unless BeastTokenizer persists/derives the training vocab. Verify.
    texts = ["this is dummy tokenizer data"]
    tokenizer = BeastTokenizer(texts)

    # Vocabulary size must match the embedding size of the saved weights.
    model = BeastSpamModel(len(tokenizer.word2idx))
    model.load_state_dict(load_file("beast_spam_model.safetensors"))
    model.eval()  # disable dropout/batch-norm training behavior

    print("\n[π] Checking email...")
    print(f"[π§ ] Result: {predict_spam(email, tokenizer, model)}")
|
|