import torch import torch.nn as nn from transformers import AutoModel ################################################################################### # Erweiterte Regressorklasse: Ein gemeinsamer Encoder, aber mehrere unabhängige Köpfe class BertMultiHeadRegressor(nn.Module): """ Mehrkopf-Regression auf einem beliebigen HF-Encoder (BERT/RoBERTa/DeBERTa/ModernBERT). - Gemeinsamer Encoder - n unabhängige Regressionsköpfe (je 1 Wert) - Robustes Pooling (Pooler wenn vorhanden, sonst maskiertes Mean) - Partielles Unfreezen ab `unfreeze_from` """ def __init__(self, pretrained_model_name: str, n_heads: int = 8, unfreeze_from: int = 8, dropout: float = 0.1): super().__init__() # Beliebigen Encoder laden self.encoder = AutoModel.from_pretrained( pretrained_model_name, low_cpu_mem_usage=False # vermeidet accelerate-Abhängigkeit zur Init ) hidden_size = self.encoder.config.hidden_size # Erst alles einfrieren … for p in self.encoder.parameters(): p.requires_grad = False # … dann Layer ab `unfreeze_from` freigeben (falls vorhanden) # Die meisten Encoder haben `.encoder.layer` encoder_block = getattr(self.encoder, "encoder", None) layers = getattr(encoder_block, "layer", None) if layers is not None: for layer in layers[unfreeze_from:]: for p in layer.parameters(): p.requires_grad = True else: # Fallback: wenn kein klassisches Lagen-Array existiert, nichts tun pass self.dropout = nn.Dropout(dropout) self.heads = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(n_heads)]) def _pool(self, outputs, attention_mask): """ Robustes Pooling: - Wenn pooler_output vorhanden: nutzen (BERT/RoBERTa) - Sonst: maskiertes Mean-Pooling über last_hidden_state (z. B. DeBERTaV3) """ pooler = getattr(outputs, "pooler_output", None) if pooler is not None: return pooler # [B, H] last_hidden = outputs.last_hidden_state # [B, T, H] mask = attention_mask.unsqueeze(-1).float() # [B, T, 1] summed = (last_hidden * mask).sum(dim=1) # [B, H] denom = mask.sum(dim=1).clamp(min=1e-6) # [B, 1] return summed / denom def forward(self, input_ids, attention_mask, token_type_ids=None): outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids if token_type_ids is not None else None, return_dict=True ) pooled = self._pool(outputs, attention_mask) # [B, H] pooled = self.dropout(pooled) preds = [head(pooled) for head in self.heads] # n × [B, 1] return torch.cat(preds, dim=1) # [B, n_heads]