Upload fine-tuned model, tokenizer, and supporting files for modernbert-imdb-sentiment
Browse files- config.yaml +1 -0
- inference.py +83 -35
config.yaml
CHANGED
@@ -4,6 +4,7 @@ model:
|
|
4 |
max_length: 880 # 256
|
5 |
dropout: 0.1
|
6 |
pooling_strategy: "mean" # Current default, change as needed
|
|
|
7 |
|
8 |
inference:
|
9 |
# Default path, can be overridden
|
|
|
4 |
max_length: 880 # 256
|
5 |
dropout: 0.1
|
6 |
pooling_strategy: "mean" # Current default, change as needed
|
7 |
+
num_weighted_layers: 6 # Match original training config
|
8 |
|
9 |
inference:
|
10 |
# Default path, can be overridden
|
inference.py
CHANGED
@@ -1,58 +1,106 @@
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
|
3 |
-
# models.py (containing ModernBertForSentiment) will be loaded from the Hub due to trust_remote_code=True
|
4 |
from typing import Dict, Any
|
5 |
import yaml
|
|
|
|
|
6 |
|
7 |
class SentimentInference:
|
8 |
def __init__(self, config_path: str = "config.yaml"):
|
9 |
-
"""Load configuration and initialize model and tokenizer from Hugging Face Hub."""
|
|
|
10 |
with open(config_path, 'r') as f:
|
11 |
config_data = yaml.safe_load(f)
|
|
|
12 |
|
13 |
model_yaml_cfg = config_data.get('model', {})
|
14 |
inference_yaml_cfg = config_data.get('inference', {})
|
15 |
|
16 |
model_hf_repo_id = model_yaml_cfg.get('name_or_path')
|
17 |
-
if not model_hf_repo_id:
|
18 |
-
raise ValueError("model.name_or_path must be specified in config.yaml (e.g., 'username/model_name')")
|
19 |
-
|
20 |
tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
|
|
|
|
|
|
|
|
|
21 |
|
22 |
self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
# but if ModernBertForSentiment.__init__ requires it, it must be provided.
|
43 |
-
# Assuming it's not critical for basic inference here to simplify.
|
44 |
-
# loaded_config.loss_function = model_yaml_cfg.get('loss_function', {'name': '...', 'params': {}})
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
self.model.eval()
|
55 |
-
print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
|
56 |
|
57 |
def predict(self, text: str) -> Dict[str, Any]:
|
58 |
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
|
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
|
|
|
3 |
from typing import Dict, Any
|
4 |
import yaml
|
5 |
+
import os
|
6 |
+
from models import ModernBertForSentiment
|
7 |
|
8 |
class SentimentInference:
|
9 |
def __init__(self, config_path: str = "config.yaml"):
|
10 |
+
"""Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub."""
|
11 |
+
print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---") # Add this
|
12 |
with open(config_path, 'r') as f:
|
13 |
config_data = yaml.safe_load(f)
|
14 |
+
print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---") # Add this
|
15 |
|
16 |
model_yaml_cfg = config_data.get('model', {})
|
17 |
inference_yaml_cfg = config_data.get('inference', {})
|
18 |
|
19 |
model_hf_repo_id = model_yaml_cfg.get('name_or_path')
|
|
|
|
|
|
|
20 |
tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
|
21 |
+
local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file
|
22 |
+
|
23 |
+
print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---") # Add this
|
24 |
+
print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---") # Add this
|
25 |
|
26 |
self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
|
27 |
|
28 |
+
# --- Tokenizer Loading (always from Hub for now, or could be made conditional) ---
|
29 |
+
if not tokenizer_hf_repo_id and not model_hf_repo_id:
|
30 |
+
raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
|
31 |
+
effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
|
32 |
+
print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
|
33 |
+
self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
|
34 |
+
|
35 |
+
# --- Model Loading --- #
|
36 |
+
# Determine if we are loading from a local .pt file or from Hugging Face Hub
|
37 |
+
load_from_local_pt = False
|
38 |
+
if local_model_weights_path and os.path.isfile(local_model_weights_path):
|
39 |
+
print(f"Found local model weights path: {local_model_weights_path}")
|
40 |
+
print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
|
41 |
+
load_from_local_pt = True
|
42 |
+
elif not model_hf_repo_id:
|
43 |
+
raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
|
44 |
+
|
45 |
+
print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
|
|
|
|
|
|
|
46 |
|
47 |
+
if load_from_local_pt:
|
48 |
+
print("Attempting to load model from local .pt checkpoint...")
|
49 |
+
print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
|
50 |
+
# Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
|
51 |
+
# This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
|
52 |
+
base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
|
53 |
+
print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
|
54 |
+
if not base_model_for_config_id:
|
55 |
+
raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
|
56 |
+
|
57 |
+
print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
|
58 |
+
bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)
|
59 |
+
|
60 |
+
# Augment config with parameters from model_yaml_cfg
|
61 |
+
bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
|
62 |
+
bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
|
63 |
+
bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
|
64 |
+
bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
|
65 |
+
# bert_config.loss_function = model_yaml_cfg.get('loss_function') # If needed by __init__
|
66 |
+
|
67 |
+
print("Instantiating ModernBertForSentiment model structure...")
|
68 |
+
self.model = ModernBertForSentiment(bert_config)
|
69 |
+
|
70 |
+
print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
|
71 |
+
checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
|
72 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
73 |
+
model_state_to_load = checkpoint['model_state_dict']
|
74 |
+
else:
|
75 |
+
model_state_to_load = checkpoint # Assume it's the state_dict itself
|
76 |
+
self.model.load_state_dict(model_state_to_load)
|
77 |
+
print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
|
78 |
+
|
79 |
+
else: # Load from Hugging Face Hub
|
80 |
+
print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
|
81 |
+
print(f"--- Debug: Entering HUGGING FACE HUB loading path ---") # Add this
|
82 |
+
print(f"--- Debug: model_hf_repo_id (for Hub loading): {model_hf_repo_id} ---") # Add this
|
83 |
+
if not model_hf_repo_id:
|
84 |
+
raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
|
85 |
+
|
86 |
+
print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
|
87 |
+
loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
|
88 |
+
|
89 |
+
# Augment loaded_config
|
90 |
+
loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
|
91 |
+
loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
|
92 |
+
loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
|
93 |
+
loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
|
94 |
+
|
95 |
+
print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
|
96 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
97 |
+
model_hf_repo_id,
|
98 |
+
config=loaded_config,
|
99 |
+
trust_remote_code=True
|
100 |
+
)
|
101 |
+
print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
|
102 |
+
|
103 |
self.model.eval()
|
|
|
104 |
|
105 |
def predict(self, text: str) -> Dict[str, Any]:
|
106 |
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
|