import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
from typing import Dict, Any
import yaml
import os
from models import ModernBertForSentiment
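
# Expected config.yaml layout (an illustrative sketch: the key names are inferred from
# the lookups in __init__ below; the values are placeholders, not project defaults):
#
#   model:
#     name_or_path: "your-username/your-sentiment-model"    # Hub repo id (hypothetical)
#     tokenizer_name_or_path: "answerdotai/ModernBERT-base"
#     base_model_for_config: "answerdotai/ModernBERT-base"  # required for local .pt loading
#     pooling_strategy: "mean"
#     num_weighted_layers: 4
#     dropout: 0.1
#     num_labels: 1
#   inference:
#     model_path: "checkpoints/best_model.pt"               # optional local checkpoint
#     max_length: 512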

class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub."""
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---") # Add this
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---") # Add this
        
        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})
        
        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path')  # Path to a local .pt checkpoint

        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        # --- Tokenizer Loading (always from Hub for now, or could be made conditional) ---
        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # --- Model Loading ---
        # Decide whether to load from a local .pt file or from the Hugging Face Hub.
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"Found local model weights at: {local_model_weights_path}")
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")

        if load_from_local_pt:
            print("Attempting to load model from local .pt checkpoint...")
            print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
            # Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
            # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
            print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
            if not base_model_for_config_id:
                 raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
            
            print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
            bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)
            
            # Augment config with parameters from model_yaml_cfg
            bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
            # Fall back to the config's own default so a missing 'dropout' key doesn't set None.
            bert_config.classifier_dropout = model_yaml_cfg.get('dropout', bert_config.classifier_dropout)
            bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
            # bert_config.loss_function = model_yaml_cfg.get('loss_function') # If needed by __init__

            print("Instantiating ModernBertForSentiment model structure...")
            self.model = ModernBertForSentiment(bert_config)
            
            print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model_state_to_load = checkpoint['model_state_dict']
            else:
                model_state_to_load = checkpoint # Assume it's the state_dict itself
            self.model.load_state_dict(model_state_to_load)
            print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")

        else:  # Load from Hugging Face Hub
            if not model_hf_repo_id:
                raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
            print(f"Loading model from Hugging Face Hub: {model_hf_repo_id}...")

            print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
            loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
            
            # Augment loaded_config
            loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
            loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6)
            loaded_config.classifier_dropout = model_yaml_cfg.get('dropout', loaded_config.classifier_dropout)
            loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)

            print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_hf_repo_id,
                config=loaded_config,
                trust_remote_code=True,
            )
            print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
        
        self.model.eval()  # Disable dropout for deterministic inference
        
    def predict(self, text: str) -> Dict[str, Any]:
        """Classify a single text, returning the predicted sentiment and its confidence."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits")  # .get() for safety across output types
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check the model's forward pass.")
        prob = torch.sigmoid(logits).item()  # single-logit binary head: P(positive)
        sentiment = "positive" if prob > 0.5 else "negative"
        # Report confidence in the predicted class rather than the raw positive-class probability.
        confidence = prob if sentiment == "positive" else 1.0 - prob
        return {"sentiment": sentiment, "confidence": confidence}