"""Inference helpers for the LSTM text classifier.

On import, this module loads ``config.json``, ``vocab.pkl`` and
``pytorch_model.bin`` from the directory containing this file. It then
exposes ``predict(text)``, which returns the model's sigmoid probability
in [0, 1] (> 0.5 is read as the positive class).
"""
import json
import pickle
import string
from pathlib import Path

import nltk
import torch

from lstm_model import LSTMClassifier

PAD = 0  # token id used to right-pad sequences up to cfg['pad_len']
UNK = 1  # token id for words that are not in the vocabulary

ROOT = Path(__file__).resolve().parent

# Model hyper-parameters. Also supplies 'pad_len', which is used by
# preprocess/encode. Use a context manager so the file handle is closed.
with open(ROOT / 'config.json', encoding='utf-8') as _cfg_file:
    cfg = json.load(_cfg_file)

# SECURITY NOTE: pickle.load can execute arbitrary code. Only load the
# vocab.pkl shipped alongside this file, never one from an untrusted source.
with open(ROOT / 'vocab.pkl', 'rb') as _vocab_file:
    vocab = pickle.load(_vocab_file)  # word -> id mapping, queried via .get()

model = LSTMClassifier(**cfg).eval()
# NOTE(review): torch.load also unpickles. Consider weights_only=True on
# torch versions that support it.
model.load_state_dict(torch.load(ROOT / 'pytorch_model.bin', map_location='cpu'))

nltk.download('stopwords', quiet=True)
STOP = set(nltk.corpus.stopwords.words('english'))
PUNC = str.maketrans('', '', string.punctuation)


def preprocess(text):
    """Tokenize *text* for the model.

    Lower-cases the text, strips ASCII punctuation, splits on whitespace and
    drops English stop words.

    Returns:
        At most ``cfg['pad_len']`` tokens, as a list of str.
    """
    text = text.lower().translate(PUNC)
    toks = [w for w in text.split() if w not in STOP]
    return toks[: cfg['pad_len']]


def encode(tokens):
    """Map *tokens* to a padded id tensor plus its true length.

    Unknown words map to ``UNK``. The sequence is truncated or right-padded
    with ``PAD`` to exactly ``cfg['pad_len']`` ids, so the result has a fixed
    width even when the caller did not run ``preprocess`` first.

    Returns:
        ``(ids, length)``. ``ids`` has shape ``(1, pad_len)``. ``length`` is a
        1-element tensor holding the number of real (non-pad) tokens.
    """
    pad_len = cfg['pad_len']
    # Truncate here too. Otherwise an over-long list yields a tensor wider than
    # pad_len, because multiplying a list by a negative count gives no padding.
    tokens = tokens[:pad_len]
    ids = [vocab.get(w, UNK) for w in tokens]
    ids += [PAD] * (pad_len - len(ids))
    # NOTE(review): if every word is a stop word, len(tokens) == 0. That fails
    # if the model packs sequences (pack_padded_sequence requires lengths > 0).
    # Confirm against LSTMClassifier.
    return torch.tensor(ids).unsqueeze(0), torch.tensor([len(tokens)])


@torch.no_grad()
def predict(text):
    """Return the model's probability (0-1) for *text*; > 0.5 means positive."""
    x, length = encode(preprocess(text))
    logit = model(x, length)
    return torch.sigmoid(logit).item()