File size: 1,040 Bytes
9eced06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch, pickle, json, string, nltk
from pathlib import Path
from lstm_model import LSTMClassifier
# Token ids used by encode(): PAD fills the tail of short sequences and UNK
# stands in for any word that is missing from the vocabulary.
PAD = 0
UNK = 1

# Every artifact below is resolved relative to the directory of this file.
ROOT = Path(__file__).resolve().parent

# Model configuration: unpacked as keyword arguments into LSTMClassifier and
# also read for 'pad_len' by preprocess()/encode(). Opened with a context
# manager so the handle is closed deterministically, and with an explicit
# encoding so parsing does not depend on the platform locale.
with open(ROOT / 'config.json', encoding='utf-8') as cfg_file:
    cfg = json.load(cfg_file)

# Word -> id mapping (queried with .get() in encode()).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a vocab.pkl that comes from a trusted source.
with open(ROOT / 'vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Build the classifier in eval mode and load its weights onto the CPU.
model = LSTMClassifier(**cfg).eval()
model.load_state_dict(torch.load(ROOT / 'pytorch_model.bin', map_location='cpu'))

# Fetch the NLTK stopword corpus if it is not already present (quiet=True
# suppresses the download log), then cache the English list as a set so the
# membership test in preprocess() is O(1).
nltk.download('stopwords', quiet=True)
STOP = set(nltk.corpus.stopwords.words('english'))

# Translation table that deletes every character in string.punctuation.
PUNC = str.maketrans('', '', string.punctuation)
def preprocess(text):
    """Turn raw text into a list of tokens for the classifier.

    The text is lower-cased, stripped of every character in
    ``string.punctuation`` (via the ``PUNC`` table), split on whitespace,
    and filtered against the ``STOP`` stopword set.

    Args:
        text: The input string.

    Returns:
        The surviving tokens in their original order, truncated to at most
        ``cfg['pad_len']`` entries.
    """
    cleaned = text.lower().translate(PUNC)
    kept = []
    for word in cleaned.split():
        if word in STOP:
            continue
        kept.append(word)
    limit = cfg['pad_len']
    return kept[:limit]
def encode(tokens):
    """Convert a token list into the tensors passed to the model.

    Each token is looked up in ``vocab``; tokens that are not present map to
    ``UNK``. The id list is then right-padded with ``PAD`` up to
    ``cfg['pad_len']`` (no padding is added when it is already that long
    or longer).

    Args:
        tokens: A sequence of token strings (must support ``len``).

    Returns:
        A ``(batch, lengths)`` tuple: ``batch`` is the padded id list with a
        leading dimension of size 1 added, and ``lengths`` is a one-element
        tensor holding the number of tokens before padding.
    """
    n_tokens = len(tokens)
    token_ids = [vocab.get(tok, UNK) for tok in tokens]
    # A non-positive count multiplies out to an empty list, i.e. no padding.
    padding = [PAD] * (cfg['pad_len'] - n_tokens)
    batch = torch.tensor(token_ids + padding).unsqueeze(0)
    lengths = torch.tensor([n_tokens])
    return batch, lengths
@torch.no_grad()
def predict(text):
    """Score a piece of text with the loaded classifier.

    The text is run through ``preprocess`` and ``encode``, the model is
    called on the resulting id tensor and length tensor, and a sigmoid is
    applied to its output. Gradient tracking is disabled for the whole call.

    Args:
        text: The raw input string.

    Returns:
        A Python float in the range 0-1; values above 0.5 indicate the
        positive class.
    """
    # NOTE(review): text made up only of stopwords/punctuation yields a
    # length of 0 here -- confirm the model accepts that.
    tokens = preprocess(text)
    inputs, lengths = encode(tokens)
    score = model(inputs, lengths)
    return torch.sigmoid(score).item()  # 0-1, >0.5 → positive
|