landify-scam-detector / app /predictor.py
anh-khoa-nguyen
init
f606230
# app/predictor.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import py_vncorenlp
from pathlib import Path # <-- THAY ĐỔI 1: Import Path
class Predictor:
def __init__(self):
"""
Hàm khởi tạo, thực hiện tải tất cả các model cần thiết MỘT LẦN
khi service khởi động.
"""
# <-- THAY ĐỔI 2: Xây dựng đường dẫn tuyệt đối
# Lấy đường dẫn của file hiện tại (predictor.py)
model_hub_path = "dorangao/landify-scam-detector-phobert"
current_file_path = Path(__file__).resolve()
# Đi lên thư mục cha (từ app/ -> scam_detector/)
project_root = current_file_path.parent.parent
vncorenlp_path = project_root / "vncorenlp_model"
# --- Kết thúc thay đổi ---
print(f"Loading model from Hugging Face Hub: {model_hub_path}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {self.device}")
self.tokenizer = AutoTokenizer.from_pretrained(model_hub_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_hub_path)
self.model.to(self.device)
self.model.eval()
self.vncorenlp_segmenter = py_vncorenlp.VnCoreNLP(save_dir=str(vncorenlp_path))
print("Models loaded successfully!")
def predict(self, title: str, content: str, threshold: float = 0.8) -> dict:
# ... (Phần còn lại của hàm predict không thay đổi)
text = f"{title}. {content}"
segmented_sentences = self.vncorenlp_segmenter.word_segment(text)
segmented_text = " ".join(segmented_sentences)
inputs = self.tokenizer(
segmented_text,
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
)
inputs = {key: val.to(self.device) for key, val in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1).squeeze()
scam_score = probabilities[1].item()
is_scam = scam_score >= threshold
return {
"scam_score": round(scam_score, 4),
"is_scam": bool(is_scam),
# <-- THAY ĐỔI 3 (sẽ thực hiện ở bước sau): Đổi tên trường này
"version": "phobert-v1"
}