# File size: 1,040 Bytes
# 9eced06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

import torch, pickle, json, string, nltk
from pathlib import Path
from lstm_model import LSTMClassifier

# Reserved vocabulary ids: PAD fills unused positions up to cfg['pad_len'],
# UNK stands in for any token that is missing from `vocab`.
PAD = 0
UNK = 1
# Directory containing this file; every artifact below is loaded from it.
ROOT = Path(__file__).resolve().parent

# read_text/read_bytes open and close the file themselves, so no file handle
# is left dangling the way a bare `open(...)` passed to a loader would be.
cfg   = json.loads((ROOT / 'config.json').read_text(encoding='utf-8'))
# NOTE(security): unpickling can execute arbitrary code -- vocab.pkl must come
# from a trusted source (it is expected to ship alongside this file).
vocab = pickle.loads((ROOT / 'vocab.pkl').read_bytes())

# Every key in config.json is forwarded as a keyword argument to the model
# constructor; eval() is applied before the weights are loaded onto the CPU.
model = LSTMClassifier(**cfg).eval()
model.load_state_dict(torch.load(ROOT / 'pytorch_model.bin', map_location='cpu'))

# Fetch the NLTK stopword corpus if necessary (quiet=True silences the
# download output), then build the English stopword set used by preprocess().
nltk.download('stopwords', quiet=True)
STOP = set(nltk.corpus.stopwords.words('english'))
# Translation table that deletes every character in string.punctuation.
PUNC = str.maketrans('', '', string.punctuation)

def preprocess(text):
    """Turn raw text into a list of at most ``cfg['pad_len']`` tokens.

    The text is lower-cased, stripped of the characters removed by the
    ``PUNC`` translation table, and split on whitespace. Tokens found in
    ``STOP`` are dropped, and the remaining list is cut to the first
    ``cfg['pad_len']`` entries.

    Args:
        text: The input string.

    Returns:
        The surviving tokens, in their original order.
    """
    cleaned = text.lower().translate(PUNC)
    kept = []
    for word in cleaned.split():
        if word in STOP:
            continue
        kept.append(word)
    limit = cfg['pad_len']
    return kept[:limit]

def encode(tokens):
    """Map tokens to vocabulary ids and wrap them as a one-item batch.

    Each token is looked up in ``vocab``; tokens that are absent map to
    ``UNK``. The id list is right-padded with ``PAD`` up to
    ``cfg['pad_len']``. No truncation happens here: a token list longer
    than ``cfg['pad_len']`` is returned at its full length.

    Args:
        tokens: A sized sequence of token strings.

    Returns:
        A pair ``(ids, length)``: ``ids`` is the padded id list as a tensor
        with a leading dimension of size 1, and ``length`` is a one-element
        tensor holding ``len(tokens)`` (the unpadded token count).
    """
    indices = []
    for token in tokens:
        indices.append(vocab.get(token, UNK))
    # A non-positive shortfall multiplies to an empty list, i.e. no padding.
    shortfall = cfg['pad_len'] - len(indices)
    padded = indices + [PAD] * shortfall
    batch = torch.tensor(padded).unsqueeze(0)
    # NOTE(review): an empty token list yields a length of 0 here -- confirm
    # the model accepts that.
    length = torch.tensor([len(tokens)])
    return batch, length

@torch.no_grad()
def predict(text):
    """Score a piece of text with the loaded model.

    The text goes through ``preprocess`` and ``encode``; the resulting id
    tensor and length tensor are passed to ``model``, and the model output
    is passed through a sigmoid. Gradient tracking is disabled for the
    whole call.

    Args:
        text: The raw input string.

    Returns:
        A float between 0 and 1; values above 0.5 mean positive.
    """
    tokens = preprocess(text)
    batch, length = encode(tokens)
    score = torch.sigmoid(model(batch, length))
    # .item() requires the model output to hold exactly one element.
    return score.item()