voxmenthe committed on
Commit 3aaca16 · verified · 1 Parent(s): 170787e

Delete src

Files changed (3)
  1. src/config.yaml +0 -46
  2. src/inference.py +0 -79
  3. src/models.py +0 -172
src/config.yaml DELETED
@@ -1,46 +0,0 @@
- model:
-   name: "answerdotai/ModernBERT-base"
-   loss_function:
-     name: "SentimentWeightedLoss" # Options: "SentimentWeightedLoss", "SentimentFocalLoss"
-     # Parameters for the chosen loss function.
-     # For SentimentFocalLoss, common params are:
-     #   gamma_focal: 1.0 # (e.g., 2.0 for standard, -2.0 for reversed, 0 for none)
-     #   label_smoothing_epsilon: 0.05 # (e.g., 0.0 to 0.1)
-     # For SentimentWeightedLoss, params is empty:
-     params:
-       gamma_focal: 1.0
-       label_smoothing_epsilon: 0.05
-   output_dir: "checkpoints"
-   max_length: 880 # 256
-   dropout: 0.1
-   # --- Pooling Strategy --- #
-   # Options: "cls", "mean", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
-   # "cls" uses just the [CLS] token for classification
-   # "mean" uses mean pooling over final hidden states for classification
-   # "cls_mean_concat" uses both [CLS] and mean pooling over final hidden states for classification
-   # "weighted_layer" uses a weighted combination of the final hidden states from the top N layers for classification
-   # "cls_weighted_concat" uses a weighted combination of the final hidden states from the top N layers and the [CLS] token for classification
-
-   pooling_strategy: "mean" # Current default, change as needed
-
-   num_weighted_layers: 6 # Number of top BERT layers to use for 'weighted_layer' strategies (e.g., 1 to 12 for BERT-base)
-
- data:
-   # No specific data paths needed as we use HF datasets at the moment
-
- training:
-   epochs: 6
-   batch_size: 16
-   lr: 1e-5 # 1e-5 # 2.0e-5
-   weight_decay_rate: 0.02 # 0.01
-   resume_from_checkpoint: "" # "checkpoints/mean_epoch2_0.9361acc_0.9355f1.pt" # Path to checkpoint file, or empty to not resume
-
- inference:
-   # Default path, can be overridden
-   model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
-   # Using the same max_length as training for consistency
-   max_length: 880 # 256
-
-
- # "answerdotai/ModernBERT-base"
- # "answerdotai/ModernBERT-large"
src/inference.py DELETED
@@ -1,79 +0,0 @@
- import torch
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- from src.models import ModernBertForSentiment
- from transformers import ModernBertConfig
- from typing import Dict, Any
- import yaml
- import os
-
-
- class SentimentInference:
-     def __init__(self, config_path: str = "config.yaml"):
-         """Load configuration and initialize model and tokenizer."""
-         with open(config_path, 'r') as f:
-             config = yaml.safe_load(f)
-
-         model_cfg = config.get('model', {})
-         inference_cfg = config.get('inference', {})
-
-         # Path to the .pt model weights file
-         model_weights_path = inference_cfg.get('model_path',
-             os.path.join(model_cfg.get('output_dir', 'checkpoints'), 'best_model.pt'))
-
-         # Base model name from config (e.g., 'answerdotai/ModernBERT-base')
-         # This will be used for loading both tokenizer and base BERT config from Hugging Face Hub
-         base_model_name = model_cfg.get('name', 'answerdotai/ModernBERT-base')
-
-         self.max_length = inference_cfg.get('max_length', model_cfg.get('max_length', 256))
-
-         # Load tokenizer from the base model name (e.g., from Hugging Face Hub)
-         print(f"Loading tokenizer from: {base_model_name}")
-         self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
-
-         # Load base BERT config from the base model name
-         print(f"Loading ModernBertConfig from: {base_model_name}")
-         bert_config = ModernBertConfig.from_pretrained(base_model_name)
-
-         # --- Apply any necessary overrides from your config to the loaded bert_config ---
-         # For example, if your ModernBertForSentiment expects specific config values beyond the base BERT model.
-         # Your current ModernBertForSentiment takes the entire config object, which might implicitly carry these.
-         # However, explicitly setting them on bert_config loaded from HF is safer if they are architecturally relevant.
-         bert_config.classifier_dropout = model_cfg.get('dropout', bert_config.classifier_dropout)  # Example
-         # Ensure num_labels is set if your inference model needs it (usually for HF pipeline, less so for manual predict)
-         # bert_config.num_labels = model_cfg.get('num_labels', 1)  # Typically 1 for binary sentiment regression-style output
-
-         # It's also important that pooling_strategy and num_weighted_layers are set on the config object
-         # that ModernBertForSentiment receives, as it uses these to build its layers.
-         # These are usually fine-tuning specific, not part of the base HF config, so they should come from your model_cfg.
-         bert_config.pooling_strategy = model_cfg.get('pooling_strategy', 'cls')
-         bert_config.num_weighted_layers = model_cfg.get('num_weighted_layers', 4)
-         bert_config.loss_function = model_cfg.get('loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})  # Needed by model init
-         # Ensure num_labels is explicitly set for the model's classifier head
-         bert_config.num_labels = 1  # For sentiment (positive/negative) often treated as 1 logit output
-
-         print("Instantiating ModernBertForSentiment model structure...")
-         self.model = ModernBertForSentiment(bert_config)
-
-         print(f"Loading model weights from local checkpoint: {model_weights_path}")
-         # Load the entire checkpoint dictionary first
-         checkpoint = torch.load(model_weights_path, map_location=torch.device('cpu'))
-
-         # Extract the model_state_dict from the checkpoint
-         # This handles the case where the checkpoint saves more than just the model weights (e.g., optimizer state, epoch)
-         if 'model_state_dict' in checkpoint:
-             model_state_to_load = checkpoint['model_state_dict']
-         else:
-             # If the checkpoint is just the state_dict itself (older format or different saving convention)
-             model_state_to_load = checkpoint
-
-         self.model.load_state_dict(model_state_to_load)
-         self.model.eval()
-         print("Model loaded successfully.")
-
-     def predict(self, text: str) -> Dict[str, Any]:
-         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length)
-         with torch.no_grad():
-             outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
-         logits = outputs["logits"]
-         prob = torch.sigmoid(logits).item()
-         return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}
src/models.py DELETED
@@ -1,172 +0,0 @@
- from transformers import ModernBertModel, ModernBertPreTrainedModel
- from transformers.modeling_outputs import SequenceClassifierOutput
- from torch import nn
- import torch
- from src.train_utils import SentimentWeightedLoss, SentimentFocalLoss
- import torch.nn.functional as F
-
- from src.classifiers import ClassifierHead, ConcatClassifierHead
-
-
- class ModernBertForSentiment(ModernBertPreTrainedModel):
-     """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.num_labels = config.num_labels
-         self.bert = ModernBertModel(config)  # Base BERT model, config may have output_hidden_states=True
-
-         # Store pooling strategy from config
-         self.pooling_strategy = getattr(config, 'pooling_strategy', 'cls')  # Default to 'cls'
-         self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)
-
-         if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
-             # This check is more of an assertion; train.py should set output_hidden_states=True
-             raise ValueError(
-                 "output_hidden_states must be True in BertConfig for weighted_layer pooling."
-             )
-
-         # Initialize weights for weighted layer pooling
-         if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
-             # num_weighted_layers specifies how many *top* layers of BERT to use.
-             # If num_weighted_layers is e.g. 4, we use the last 4 layers.
-             self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)
-
-         # Determine classifier input size and choose head
-         classifier_input_size = config.hidden_size
-         if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
-             classifier_input_size = config.hidden_size * 2
-
-         # Dropout for features fed into the classifier head
-         classifier_dropout_prob = (
-             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-         )
-         self.features_dropout = nn.Dropout(classifier_dropout_prob)
-
-         # Select the appropriate classifier head based on input feature dimension
-         if classifier_input_size == config.hidden_size:
-             self.classifier = ClassifierHead(
-                 hidden_size=config.hidden_size,  # input_size for ClassifierHead is just hidden_size
-                 num_labels=config.num_labels,
-                 dropout_prob=classifier_dropout_prob
-             )
-         elif classifier_input_size == config.hidden_size * 2:
-             self.classifier = ConcatClassifierHead(
-                 input_size=config.hidden_size * 2,
-                 hidden_size=config.hidden_size,  # Internal hidden size of the head
-                 num_labels=config.num_labels,
-                 dropout_prob=classifier_dropout_prob
-             )
-         else:
-             # This case should ideally not be reached with current strategies
-             raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")
-
-         # Initialize loss function based on config
-         loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
-         loss_name = loss_config.get('name', 'SentimentWeightedLoss')
-         loss_params = loss_config.get('params', {})
-
-         if loss_name == "SentimentWeightedLoss":
-             self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no arguments
-         elif loss_name == "SentimentFocalLoss":
-             # Ensure only relevant params are passed, or that loss_params is structured correctly for SentimentFocalLoss
-             # For SentimentFocalLoss, expected params are 'gamma_focal' and 'label_smoothing_epsilon'
-             self.loss_fct = SentimentFocalLoss(**loss_params)
-         else:
-             raise ValueError(f"Unsupported loss function: {loss_name}")
-
-         self.post_init()  # Initialize weights and apply final processing
-
-     def _mean_pool(self, last_hidden_state, attention_mask):
-         if attention_mask is None:
-             attention_mask = torch.ones_like(last_hidden_state[:, :, 0])  # Assuming first dim of last hidden state is token ids
-         input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
-         sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
-         sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-         return sum_embeddings / sum_mask
-
-     def _weighted_layer_pool(self, all_hidden_states):
-         # all_hidden_states includes embeddings + output of each layer.
-         # We want the outputs of the last num_weighted_layers.
-         # Example: 12 layers -> all_hidden_states have 13 items (embeddings + 12 layers)
-         # num_weighted_layers = 4 -> use layers 9, 10, 11, 12 (indices -4, -3, -2, -1)
-         layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
-         # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)
-
-         # Normalize weights to sum to 1 (softmax or simple division)
-         normalized_weights = F.softmax(self.layer_weights, dim=-1)
-
-         # Weighted sum across layers
-         # Reshape weights for broadcasting: (num_weighted_layers, 1, 1, 1)
-         weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
-         weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
-         # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)
-
-         # Pool the result (e.g., take [CLS] token of this weighted sum)
-         return weighted_sum_hidden_states[:, 0]  # Return CLS token of the weighted sum
-
-     def forward(
-         self,
-         input_ids=None,
-         attention_mask=None,
-         labels=None,
-         lengths=None,
-         return_dict=None,
-         **kwargs
-     ):
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         bert_outputs = self.bert(
-             input_ids,
-             attention_mask=attention_mask,
-             return_dict=return_dict,
-             output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
-         )
-
-         last_hidden_state = bert_outputs[0]  # Or bert_outputs.last_hidden_state
-         pooled_features = None
-
-         if self.pooling_strategy == 'cls':
-             pooled_features = last_hidden_state[:, 0]  # CLS token
-         elif self.pooling_strategy == 'mean':
-             pooled_features = self._mean_pool(last_hidden_state, attention_mask)
-         elif self.pooling_strategy == 'cls_mean_concat':
-             cls_output = last_hidden_state[:, 0]
-             mean_output = self._mean_pool(last_hidden_state, attention_mask)
-             pooled_features = torch.cat((cls_output, mean_output), dim=1)
-         elif self.pooling_strategy == 'weighted_layer':
-             if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
-                 raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
-             all_hidden_states = bert_outputs.hidden_states
-             pooled_features = self._weighted_layer_pool(all_hidden_states)
-         elif self.pooling_strategy == 'cls_weighted_concat':
-             if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
-                 raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
-             cls_output = last_hidden_state[:, 0]
-             all_hidden_states = bert_outputs.hidden_states
-             weighted_output = self._weighted_layer_pool(all_hidden_states)
-             pooled_features = torch.cat((cls_output, weighted_output), dim=1)
-         else:
-             raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")
-
-         pooled_features = self.features_dropout(pooled_features)
-         logits = self.classifier(pooled_features)
-
-         loss = None
-         if labels is not None:
-             if lengths is None:
-                 raise ValueError("lengths must be provided when labels are specified for loss calculation.")
-             loss = self.loss_fct(logits.squeeze(-1), labels, lengths)
-
-         if not return_dict:
-             # Ensure 'outputs' from BERT is appropriately handled. If it's a tuple:
-             bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
-             output = (logits,) + bert_model_outputs
-             return ((loss,) + output) if loss is not None else output
-
-         return SequenceClassifierOutput(
-             loss=loss,
-             logits=logits,
-             hidden_states=bert_outputs.hidden_states,
-             attentions=bert_outputs.attentions,
-         )
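The masked mean pooling implemented by the deleted _mean_pool method can be checked in isolation. The standalone sketch below uses dummy tensors (no ModernBERT weights) and mirrors the same computation: zero out padded positions, sum over the sequence, and divide by the clamped count of real tokens:

import torch

# Illustration of the masked mean pooling from the deleted _mean_pool method.
last_hidden_state = torch.randn(2, 4, 8)          # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])     # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
sum_embeddings = torch.sum(last_hidden_state * mask, dim=1)
sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)  # avoid division by zero for all-padding rows
mean_pooled = sum_embeddings / sum_mask             # (batch, hidden)
print(mean_pooled.shape)                             # torch.Size([2, 8])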