import os
from typing import Any, Dict

import torch
import yaml
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig

from models import ModernBertForSentiment

class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize the model and tokenizer from a local checkpoint or the Hugging Face Hub."""
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---")
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---")

        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})
        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path')  # Path to a local .pt checkpoint file
        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---")
        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---")

        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
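
        # For reference, this constructor assumes a config.yaml shaped roughly as
        # follows. Key names are taken from the .get() calls in this file; the
        # values shown are illustrative placeholders, not defaults:
        #
        #   model:
        #     name_or_path: "your-user/your-sentiment-model"      # hypothetical repo ID
        #     tokenizer_name_or_path: "answerdotai/ModernBERT-base"
        #     base_model_for_config: "answerdotai/ModernBERT-base"
        #     pooling_strategy: "mean"
        #     num_weighted_layers: 4
        #     dropout: 0.1
        #     num_labels: 1
        #     max_length: 512
        #   inference:
        #     model_path: "checkpoints/best_model.pt"             # hypothetical local path
        #     max_length: 512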

        # --- Tokenizer loading (always from the Hub for now; could be made conditional) ---
        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as a fallback for the tokenizer) must be specified in config.yaml")
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # --- Model loading ---
        # Determine whether to load from a local .pt file or from the Hugging Face Hub.
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"Found local model weights path: {local_model_weights_path}")
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub loading) is not specified in config.yaml")
        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---")
        if load_from_local_pt:
            print("Attempting to load model from local .pt checkpoint...")
            # The base BERT config must still be loaded, usually from a Hub ID (e.g., the original base model).
            # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
            print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---")
            if not base_model_for_config_id:
                raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
            print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
            bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)

            # Augment the config with parameters from the model section of config.yaml.
            bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
            bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
            bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
            # bert_config.loss_function = model_yaml_cfg.get('loss_function')  # Uncomment if needed by __init__

            print("Instantiating ModernBertForSentiment model structure...")
            self.model = ModernBertForSentiment(bert_config)

            print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
            # Training checkpoints may wrap the weights in a dict alongside optimizer state, etc.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model_state_to_load = checkpoint['model_state_dict']
            else:
                model_state_to_load = checkpoint  # Assume the file is the state_dict itself
            self.model.load_state_dict(model_state_to_load)
            print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
        else:  # Load from the Hugging Face Hub
            print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
            if not model_hf_repo_id:
                raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
            print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
            loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)

            # Augment the loaded config with parameters from config.yaml.
            loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6)  # Default is now 6
            loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
            loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)

            print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_hf_repo_id,
                config=loaded_config,
                trust_remote_code=True,
                force_download=True  # TEMPORARY: remove once everything is working
            )
            print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
        self.model.eval()  # Inference mode: disables dropout

    def predict(self, text: str) -> Dict[str, Any]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits")  # Use .get for safety
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check the model's forward pass.")
        # Single-logit binary head (num_labels=1): sigmoid yields the positive-class probability.
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}