"""Multi-task bias classifier: one shared DeBERTa encoder, one head per task."""

from typing import Sequence

import torch
import torch.nn as nn
from transformers import DebertaV2Model

# Task names used when the caller does not supply its own list.
DEFAULT_TASKS = ("political", "gender", "immigration")


class MultiTaskBiasModel(nn.Module):
    """Shared encoder with a separate classification head for each task.

    Every example in a batch is routed to the head named by its entry in
    ``tasks``; all heads emit logits of the same width, so the per-example
    results can be returned as a single tensor.
    """

    def __init__(
        self,
        model_name: str = "microsoft/deberta-v3-base",
        *,
        task_names: Sequence[str] = DEFAULT_TASKS,
        num_labels: int = 3,
        dropout: float = 0.2,
    ) -> None:
        """Load the pretrained encoder and build one head per task.

        Args:
            model_name: Identifier passed to ``DebertaV2Model.from_pretrained``.
            task_names: Names of the tasks; each becomes a key of ``self.heads``.
            num_labels: Number of output logits produced by every head.
            dropout: Dropout probability used inside every head.

        Raises:
            ValueError: If ``task_names`` is empty.
        """
        super().__init__()
        if not task_names:
            raise ValueError("task_names must contain at least one task")

        self.bert = DebertaV2Model.from_pretrained(model_name)
        self.num_labels = num_labels
        hidden = self.bert.config.hidden_size

        # Layer order inside each head is kept as Linear/ReLU/Dropout/Linear so
        # that parameter names (heads.<task>.0.*, heads.<task>.3.*) stay stable
        # for existing checkpoints.
        self.heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, num_labels),
            )
            for task in task_names
        })

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        tasks: Sequence[str],
    ) -> torch.Tensor:
        """Encode the batch and apply each example's task-specific head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            tasks: One task name per batch row, in batch order.

        Returns:
            Logits with one row per batch example (same order as the input)
            and ``num_labels`` columns.

        Raises:
            ValueError: If ``len(tasks)`` differs from the batch size.
            KeyError: If a task name has no corresponding head.
        """
        # Use the hidden state at sequence position 0 as the example summary
        # (presumably the [CLS] token for DeBERTa tokenizers).
        pooled = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]

        batch_size = pooled.size(0)
        if len(tasks) != batch_size:
            raise ValueError(
                f"expected one task per example: got {len(tasks)} tasks "
                f"for a batch of {batch_size}"
            )
        if batch_size == 0:
            # torch.cat rejects an empty list, so return an empty result directly.
            return pooled.new_zeros((0, self.num_labels))

        # Collect the row indices belonging to each task so every head runs
        # once on a real mini-batch instead of once per example.
        rows_by_task = {}
        for row, task in enumerate(tasks):
            rows_by_task.setdefault(task, []).append(row)

        unknown = [task for task in rows_by_task if task not in self.heads]
        if unknown:
            raise KeyError(
                f"unknown task(s) {unknown}; known tasks: {list(self.heads.keys())}"
            )

        chunks = []
        order = []
        for task, rows in rows_by_task.items():
            index = torch.tensor(rows, dtype=torch.long, device=pooled.device)
            chunks.append(self.heads[task](pooled.index_select(0, index)))
            order.extend(rows)

        grouped = torch.cat(chunks, dim=0)
        # ``order`` is a permutation of 0..batch_size-1; its argsort maps each
        # original row back to its position in the task-grouped output.
        restore = torch.argsort(
            torch.tensor(order, dtype=torch.long, device=grouped.device)
        )
        return grouped.index_select(0, restore)