Delete src
- src/config.yaml +0 -46
- src/inference.py +0 -79
- src/models.py +0 -172
src/config.yaml
DELETED
@@ -1,46 +0,0 @@
model:
  name: "answerdotai/ModernBERT-base"
  loss_function:
    name: "SentimentWeightedLoss" # Options: "SentimentWeightedLoss", "SentimentFocalLoss"
    # Parameters for the chosen loss function.
    # For SentimentFocalLoss, common params are:
    # gamma_focal: 1.0 # (e.g., 2.0 for standard, -2.0 for reversed, 0 for none)
    # label_smoothing_epsilon: 0.05 # (e.g., 0.0 to 0.1)
    # For SentimentWeightedLoss, params is empty:
    params:
      gamma_focal: 1.0
      label_smoothing_epsilon: 0.05
  output_dir: "checkpoints"
  max_length: 880 # 256
  dropout: 0.1
  # --- Pooling Strategy --- #
  # Options: "cls", "mean", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
  # "cls" uses just the [CLS] token for classification
  # "mean" uses mean pooling over final hidden states for classification
  # "cls_mean_concat" uses both [CLS] and mean pooling over final hidden states for classification
  # "weighted_layer" uses a weighted combination of the final hidden states from the top N layers for classification
  # "cls_weighted_concat" uses a weighted combination of the final hidden states from the top N layers and the [CLS] token for classification

  pooling_strategy: "mean" # Current default, change as needed

  num_weighted_layers: 6 # Number of top BERT layers to use for 'weighted_layer' strategies (e.g., 1 to 12 for BERT-base)

data:
  # No specific data paths needed as we use HF datasets at the moment

training:
  epochs: 6
  batch_size: 16
  lr: 1e-5 # 2.0e-5
  weight_decay_rate: 0.02 # 0.01
  resume_from_checkpoint: "" # "checkpoints/mean_epoch2_0.9361acc_0.9355f1.pt" # Path to checkpoint file, or empty to not resume

inference:
  # Default path, can be overridden
  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
  # Using the same max_length as training for consistency
  max_length: 880 # 256


# "answerdotai/ModernBERT-base"
# "answerdotai/ModernBERT-large"
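For reference, a minimal sketch (not part of the deleted files) of how a config like the one above is consumed, mirroring the yaml.safe_load pattern in src/inference.py below; only keys visible in the file are assumed, and the path is illustrative.

import yaml

# Load the YAML config from the repository root (path assumed for illustration).
with open("src/config.yaml", "r") as f:
    config = yaml.safe_load(f)

model_cfg = config.get("model", {})
print(model_cfg.get("pooling_strategy"))               # "mean"
print(model_cfg.get("num_weighted_layers"))            # 6
print(model_cfg.get("loss_function", {}).get("name"))  # "SentimentWeightedLoss"

# Caveat: PyYAML's YAML 1.1 resolver reads 1e-5 (no decimal point) as the string "1e-5",
# so the training code presumably casts lr to float before handing it to the optimizer.
print(type(config["training"]["lr"]))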
src/inference.py
DELETED
@@ -1,79 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from src.models import ModernBertForSentiment
from transformers import ModernBertConfig
from typing import Dict, Any
import yaml
import os


class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer."""
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        model_cfg = config.get('model', {})
        inference_cfg = config.get('inference', {})

        # Path to the .pt model weights file
        model_weights_path = inference_cfg.get('model_path',
                                               os.path.join(model_cfg.get('output_dir', 'checkpoints'), 'best_model.pt'))

        # Base model name from config (e.g., 'answerdotai/ModernBERT-base'), used to load
        # both the tokenizer and the base BERT config from the Hugging Face Hub.
        base_model_name = model_cfg.get('name', 'answerdotai/ModernBERT-base')

        self.max_length = inference_cfg.get('max_length', model_cfg.get('max_length', 256))

        # Load tokenizer from the base model name (e.g., from the Hugging Face Hub)
        print(f"Loading tokenizer from: {base_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Load base BERT config from the base model name
        print(f"Loading ModernBertConfig from: {base_model_name}")
        bert_config = ModernBertConfig.from_pretrained(base_model_name)

        # --- Apply overrides from the YAML config to the loaded bert_config ---
        # ModernBertForSentiment receives the whole config object, so any architecturally
        # relevant values are set explicitly here rather than relying on the base HF config.
        bert_config.classifier_dropout = model_cfg.get('dropout', bert_config.classifier_dropout)

        # pooling_strategy, num_weighted_layers and loss_function are fine-tuning specific and not
        # part of the base HF config; the model uses them to build its layers, so they must come
        # from model_cfg.
        bert_config.pooling_strategy = model_cfg.get('pooling_strategy', 'cls')
        bert_config.num_weighted_layers = model_cfg.get('num_weighted_layers', 4)
        bert_config.loss_function = model_cfg.get('loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})  # Needed by model init
        # Ensure num_labels is explicitly set for the model's classifier head
        bert_config.num_labels = 1  # Binary sentiment is treated as a single logit output

        print("Instantiating ModernBertForSentiment model structure...")
        self.model = ModernBertForSentiment(bert_config)

        print(f"Loading model weights from local checkpoint: {model_weights_path}")
        # Load the entire checkpoint dictionary first
        checkpoint = torch.load(model_weights_path, map_location=torch.device('cpu'))

        # Extract model_state_dict when the checkpoint stores more than the weights
        # (e.g., optimizer state, epoch); otherwise assume the checkpoint is the state_dict itself.
        if 'model_state_dict' in checkpoint:
            model_state_to_load = checkpoint['model_state_dict']
        else:
            model_state_to_load = checkpoint

        self.model.load_state_dict(model_state_to_load)
        self.model.eval()
        print("Model loaded successfully.")

    def predict(self, text: str) -> Dict[str, Any]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs["logits"]
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}
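A short usage sketch (not part of the deleted file) for the class above, assuming the checkpoint referenced in config.yaml exists locally and src/ is importable as a package; the config path and example sentence are illustrative.

from src.inference import SentimentInference

# predict() returns the dict built at the end of the file above, where "confidence"
# is the sigmoid probability of the positive class.
inf = SentimentInference(config_path="src/config.yaml")
print(inf.predict("The movie was surprisingly good."))
# e.g. {'sentiment': 'positive', 'confidence': 0.97}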
src/models.py
DELETED
@@ -1,172 +0,0 @@
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from src.train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F

from src.classifiers import ClassifierHead, ConcatClassifierHead


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)  # Base BERT model; config may have output_hidden_states=True

        # Store pooling strategy from config
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'cls')  # Default to 'cls'
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # This check is more of an assertion; train.py should set output_hidden_states=True
            raise ValueError(
                "output_hidden_states must be True in BertConfig for weighted_layer pooling."
            )

        # Initialize weights for weighted layer pooling
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* layers of BERT to use,
            # e.g. num_weighted_layers=4 means the last 4 layers.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)

        # Determine classifier input size and choose head
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2

        # Dropout for features fed into the classifier head
        classifier_dropout_prob = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on input feature dimension
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,  # input size for ClassifierHead is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        else:
            # This case should not be reached with the current strategies
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize loss function based on config
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})

        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no arguments
        elif loss_name == "SentimentFocalLoss":
            # For SentimentFocalLoss the expected params are 'gamma_focal' and 'label_smoothing_epsilon'
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
        if attention_mask is None:
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])  # All-ones (batch, seq_len) mask when none is provided
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
        # all_hidden_states includes the embeddings plus the output of each layer;
        # we want the outputs of the last num_weighted_layers.
        # Example: 12 layers -> all_hidden_states has 13 items (embeddings + 12 layers);
        # num_weighted_layers = 4 -> use layers 9, 10, 11, 12 (indices -4, -3, -2, -1).
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)

        # Normalize weights to sum to 1
        normalized_weights = F.softmax(self.layer_weights, dim=-1)

        # Weighted sum across layers; reshape weights for broadcasting: (num_weighted_layers, 1, 1, 1)
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)

        # Pool the result by taking the [CLS] token of the weighted sum
        return weighted_sum_hidden_states[:, 0]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
            output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
        )

        last_hidden_state = bert_outputs[0]  # Or bert_outputs.last_hidden_state
        pooled_features = None

        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0]  # [CLS] token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            all_hidden_states = bert_outputs.hidden_states
            pooled_features = self._weighted_layer_pool(all_hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            cls_output = last_hidden_state[:, 0]
            all_hidden_states = bert_outputs.hidden_states
            weighted_output = self._weighted_layer_pool(all_hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # If BERT returned a tuple, keep everything after the last hidden state;
            # otherwise expose hidden_states and attentions explicitly.
            bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
            output = (logits,) + bert_model_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
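A minimal instantiation sketch (not part of the deleted files) for the model above, mirroring how src/inference.py wires the fine-tuning fields onto a ModernBertConfig; the weighted strategies additionally require output_hidden_states=True, as the constructor asserts. All values shown are illustrative.

import torch
from transformers import ModernBertConfig
from src.models import ModernBertForSentiment

config = ModernBertConfig.from_pretrained("answerdotai/ModernBERT-base")
config.output_hidden_states = True        # required for 'weighted_layer' / 'cls_weighted_concat'
config.pooling_strategy = "cls_weighted_concat"
config.num_weighted_layers = 6            # softmax-weight the top 6 encoder layers
config.num_labels = 1                     # single-logit sentiment output
config.classifier_dropout = 0.1
config.loss_function = {"name": "SentimentWeightedLoss", "params": {}}

# Randomly initialized weights; real use would load a fine-tuned checkpoint as in inference.py.
model = ModernBertForSentiment(config)    # picks ConcatClassifierHead (2 * hidden_size input)
dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=dummy_ids, attention_mask=torch.ones_like(dummy_ids))
print(out.logits.shape)                   # torch.Size([2, 1])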