import os
from typing import Dict, Any

import torch
import yaml
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig

from models import ModernBertForSentiment


class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize the model and tokenizer from a local checkpoint or the Hugging Face Hub."""
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---")
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---")

        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})

        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path')

        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---")
        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---")

        # inference.max_length takes precedence over model.max_length; default is 512.
        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError(
                "Either model.tokenizer_name_or_path or model.name_or_path "
                "(as a fallback for the tokenizer) must be specified in config.yaml"
            )
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # Prefer a local .pt checkpoint when one exists; otherwise fall back to the Hub.
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"Found local model weights path: {local_model_weights_path}")
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")

        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---")

        if load_from_local_pt:
            print("Attempting to load model from local .pt checkpoint...")

            # A local .pt checkpoint stores only weights, so the model structure must be
            # rebuilt from a base config before the state dict can be loaded.
            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
            print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---")
            if not base_model_for_config_id:
                raise ValueError(
                    "For local .pt loading, model.base_model_for_config must be specified in config.yaml "
                    "(e.g., 'answerdotai/ModernBERT-base') to build the model structure."
                )

            print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
            bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)

            # Mirror the architecture options used at training time onto the config.
            bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
            bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
            bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)

            print("Instantiating ModernBertForSentiment model structure...")
            self.model = ModernBertForSentiment(bert_config)

            print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
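            # A wrapped training checkpoint handled below might look like this
            # (keys other than 'model_state_dict' are hypothetical examples):
            #   {'model_state_dict': {...}, 'optimizer_state_dict': {...}, 'epoch': 3}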
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model_state_to_load = checkpoint['model_state_dict']
            else:
                model_state_to_load = checkpoint
            self.model.load_state_dict(model_state_to_load)
            print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")

        else:
            print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
            if not model_hf_repo_id:
                raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")

            print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
            loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)

            # Apply the same architecture overrides as in the local-checkpoint path.
            loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
            loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
            loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)

            print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_hf_repo_id,
                config=loaded_config,
                trust_remote_code=True,
                force_download=True,  # re-downloads on every init; remove to reuse the local HF cache
            )
            print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")

        # Inference mode: disables dropout and other training-time behavior.
        self.model.eval()

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify a single text and return the predicted sentiment with its confidence."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits")
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check the model's forward pass.")
        # Single-logit binary head: sigmoid yields P(positive).
        prob = torch.sigmoid(logits).item()
        sentiment = "positive" if prob > 0.5 else "negative"
        # Report confidence in the predicted label rather than raw P(positive).
        confidence = prob if prob > 0.5 else 1.0 - prob
        return {"sentiment": sentiment, "confidence": confidence}