|
import torch
import torch.nn.functional as F
from torch import nn
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from classifiers import ClassifierHead, ConcatClassifierHead
from train_utils import SentimentWeightedLoss, SentimentFocalLoss
|
|
|
|
|
class ModernBertForSentiment(ModernBertPreTrainedModel): |
|
"""ModernBERT encoder with a dynamically configurable classification head and pooling strategy.""" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
self.bert = ModernBertModel(config) |
|
|
|
|
|
        # Pooling behaviour is driven entirely by the config so it can be chosen
        # per-experiment without code changes.
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'mean')
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)  # top N layers to combine
|
|
|
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            raise ValueError(
                "output_hidden_states must be True in the model config for the "
                "'weighted_layer' and 'cls_weighted_concat' pooling strategies."
            )
|
|
|
|
|
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # Learnable per-layer weights over the top `num_weighted_layers` hidden
            # layers, initialized uniformly and softmax-normalized in the forward pass.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)
|
|
|
|
|
classifier_input_size = config.hidden_size |
|
if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']: |
|
classifier_input_size = config.hidden_size * 2 |
|
|
|
|
|
        # ModernBertConfig defines `classifier_dropout` but, unlike classic BertConfig,
        # no `hidden_dropout_prob`, so fall back defensively.
        classifier_dropout_prob = (
            config.classifier_dropout
            if getattr(config, 'classifier_dropout', None) is not None
            else getattr(config, 'hidden_dropout_prob', 0.0)
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)
|
|
|
|
|
if classifier_input_size == config.hidden_size: |
|
self.classifier = ClassifierHead( |
|
hidden_size=config.hidden_size, |
|
num_labels=config.num_labels, |
|
dropout_prob=classifier_dropout_prob |
|
) |
|
elif classifier_input_size == config.hidden_size * 2: |
|
self.classifier = ConcatClassifierHead( |
|
input_size=config.hidden_size * 2, |
|
hidden_size=config.hidden_size, |
|
num_labels=config.num_labels, |
|
dropout_prob=classifier_dropout_prob |
|
) |
|
else: |
|
|
|
raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}") |
|
|
|
|
|
loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}}) |
|
loss_name = loss_config.get('name', 'SentimentWeightedLoss') |
|
loss_params = loss_config.get('params', {}) |
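
        # Illustrative shape of the config entry (actual parameter names for
        # SentimentFocalLoss are defined in train_utils; shown values are placeholders):
        #   config.loss_function = {"name": "SentimentFocalLoss", "params": {...}}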
|
|
|
if loss_name == "SentimentWeightedLoss": |
|
self.loss_fct = SentimentWeightedLoss() |
|
elif loss_name == "SentimentFocalLoss": |
|
|
|
|
|
self.loss_fct = SentimentFocalLoss(**loss_params) |
|
else: |
|
raise ValueError(f"Unsupported loss function: {loss_name}") |
|
|
|
self.post_init() |
|
|
|
    def _mean_pool(self, last_hidden_state, attention_mask):
        """Mean-pool token embeddings, ignoring padding positions."""
        if attention_mask is None:
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])
        # Zero out padded positions, then divide by the true token count.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
|
|
|
    def _weighted_layer_pool(self, all_hidden_states):
        """Softmax-weighted combination of the top `num_weighted_layers` hidden layers.

        `all_hidden_states` is the tuple the encoder returns when
        output_hidden_states=True: the embedding output plus one tensor per layer,
        each of shape (batch, seq_len, hidden_size).
        """
        # (num_weighted_layers, batch, seq_len, hidden_size)
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # Normalize the learned weights so they sum to 1 across layers.
        normalized_weights = F.softmax(self.layer_weights, dim=-1)
        # Broadcast weights over batch/sequence/hidden dims and sum out the layer dim.
        weighted_sum_hidden_states = torch.sum(
            layers_to_weigh * normalized_weights.view(-1, 1, 1, 1), dim=0
        )
        # Return the [CLS] position of the combined representation.
        return weighted_sum_hidden_states[:, 0]
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
attention_mask=None, |
|
labels=None, |
|
lengths=None, |
|
return_dict=None, |
|
**kwargs |
|
): |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
        # Always request a ModelOutput from the encoder so attribute access below is
        # safe; a plain tuple is reassembled at the end when return_dict is False.
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=self.config.output_hidden_states,
        )

        last_hidden_state = bert_outputs.last_hidden_state
|
pooled_features = None |
|
|
|
if self.pooling_strategy == 'cls': |
|
pooled_features = last_hidden_state[:, 0] |
|
elif self.pooling_strategy == 'mean': |
|
pooled_features = self._mean_pool(last_hidden_state, attention_mask) |
|
elif self.pooling_strategy == 'cls_mean_concat': |
|
cls_output = last_hidden_state[:, 0] |
|
mean_output = self._mean_pool(last_hidden_state, attention_mask) |
|
pooled_features = torch.cat((cls_output, mean_output), dim=1) |
|
elif self.pooling_strategy == 'weighted_layer': |
|
if not self.config.output_hidden_states or bert_outputs.hidden_states is None: |
|
raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.") |
|
all_hidden_states = bert_outputs.hidden_states |
|
pooled_features = self._weighted_layer_pool(all_hidden_states) |
|
elif self.pooling_strategy == 'cls_weighted_concat': |
|
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("cls_weighted_concat pooling requires output_hidden_states=True and hidden_states in the encoder output.")
|
cls_output = last_hidden_state[:, 0] |
|
all_hidden_states = bert_outputs.hidden_states |
|
weighted_output = self._weighted_layer_pool(all_hidden_states) |
|
pooled_features = torch.cat((cls_output, weighted_output), dim=1) |
|
else: |
|
raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}") |
|
|
|
pooled_features = self.features_dropout(pooled_features) |
|
logits = self.classifier(pooled_features) |
|
|
|
loss = None |
|
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            # Both custom losses take per-example logits of shape (batch,), hence the squeeze.
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)
|
|
|
        if not return_dict:
            # Rebuild the tuple form, dropping entries the encoder did not produce.
            extra = tuple(v for v in (bert_outputs.hidden_states, bert_outputs.attentions) if v is not None)
            output = (logits,) + extra
            return ((loss,) + output) if loss is not None else output
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=bert_outputs.hidden_states, |
|
attentions=bert_outputs.attentions, |
|
) |
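

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model): shows how the pooling strategy
# and loss are wired through the config. Assumes the public
# `answerdotai/ModernBERT-base` checkpoint plus the local `classifiers` and
# `train_utils` modules are available; the classification head is randomly
# initialized here, so outputs are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base", num_labels=1)
    config.pooling_strategy = "cls_weighted_concat"  # or: cls, mean, cls_mean_concat, weighted_layer
    config.num_weighted_layers = 4
    config.output_hidden_states = True  # required by the weighted-layer strategies

    model = ModernBertForSentiment(config)
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
    batch = tokenizer(["A quietly devastating film."], return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)
    print(out.logits.shape)  # (1, num_labels)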